1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/IR/BuiltinAttributes.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include <numeric> 25 26 using namespace mlir; 27 28 /// Gets the first integer value from `attr`, assuming it is an integer array 29 /// attribute. 30 static uint64_t getFirstIntValue(ArrayAttr attr) { 31 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 32 } 33 34 namespace { 35 36 struct VectorBitcastConvert final 37 : public OpConversionPattern<vector::BitCastOp> { 38 using OpConversionPattern::OpConversionPattern; 39 40 LogicalResult 41 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, 42 ConversionPatternRewriter &rewriter) const override { 43 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 44 if (!dstType) 45 return failure(); 46 47 if (dstType == adaptor.getSource().getType()) 48 rewriter.replaceOp(bitcastOp, adaptor.getSource()); 49 else 50 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 51 adaptor.getSource()); 52 53 return success(); 54 } 55 }; 56 57 struct VectorBroadcastConvert final 58 : public OpConversionPattern<vector::BroadcastOp> { 59 using OpConversionPattern::OpConversionPattern; 60 61 LogicalResult 62 matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, 63 ConversionPatternRewriter &rewriter) const override { 64 if (broadcastOp.getSource().getType().isa<VectorType>() || 65 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 66 return failure(); 67 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 68 adaptor.getSource()); 69 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 70 broadcastOp, broadcastOp.getVectorType(), source); 71 return success(); 72 } 73 }; 74 75 struct VectorExtractOpConvert final 76 : public OpConversionPattern<vector::ExtractOp> { 77 using OpConversionPattern::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 81 ConversionPatternRewriter &rewriter) const override { 82 // Only support extracting a scalar value now. 83 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 84 if (resultVectorType && resultVectorType.getNumElements() > 1) 85 return failure(); 86 87 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 88 if (!dstType) 89 return failure(); 90 91 if (adaptor.getVector().getType().isa<spirv::ScalarType>()) { 92 rewriter.replaceOp(extractOp, adaptor.getVector()); 93 return success(); 94 } 95 96 int32_t id = getFirstIntValue(extractOp.getPosition()); 97 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 98 extractOp, adaptor.getVector(), id); 99 return success(); 100 } 101 }; 102 103 struct VectorExtractStridedSliceOpConvert final 104 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 105 using OpConversionPattern::OpConversionPattern; 106 107 LogicalResult 108 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, 109 ConversionPatternRewriter &rewriter) const override { 110 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 111 if (!dstType) 112 return failure(); 113 114 uint64_t offset = getFirstIntValue(extractOp.getOffsets()); 115 uint64_t size = getFirstIntValue(extractOp.getSizes()); 116 uint64_t stride = getFirstIntValue(extractOp.getStrides()); 117 if (stride != 1) 118 return failure(); 119 120 Value srcVector = adaptor.getOperands().front(); 121 122 // Extract vector<1xT> case. 123 if (dstType.isa<spirv::ScalarType>()) { 124 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, 125 srcVector, offset); 126 return success(); 127 } 128 129 SmallVector<int32_t, 2> indices(size); 130 std::iota(indices.begin(), indices.end(), offset); 131 132 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 133 extractOp, dstType, srcVector, srcVector, 134 rewriter.getI32ArrayAttr(indices)); 135 136 return success(); 137 } 138 }; 139 140 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 141 using OpConversionPattern::OpConversionPattern; 142 143 LogicalResult 144 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 145 ConversionPatternRewriter &rewriter) const override { 146 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 147 return failure(); 148 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 149 fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(), 150 adaptor.getAcc()); 151 return success(); 152 } 153 }; 154 155 struct VectorInsertOpConvert final 156 : public OpConversionPattern<vector::InsertOp> { 157 using OpConversionPattern::OpConversionPattern; 158 159 LogicalResult 160 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 161 ConversionPatternRewriter &rewriter) const override { 162 // Special case for inserting scalar values into size-1 vectors. 163 if (insertOp.getSourceType().isIntOrFloat() && 164 insertOp.getDestVectorType().getNumElements() == 1) { 165 rewriter.replaceOp(insertOp, adaptor.getSource()); 166 return success(); 167 } 168 169 if (insertOp.getSourceType().isa<VectorType>() || 170 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 171 return failure(); 172 int32_t id = getFirstIntValue(insertOp.getPosition()); 173 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 174 insertOp, adaptor.getSource(), adaptor.getDest(), id); 175 return success(); 176 } 177 }; 178 179 struct VectorExtractElementOpConvert final 180 : public OpConversionPattern<vector::ExtractElementOp> { 181 using OpConversionPattern::OpConversionPattern; 182 183 LogicalResult 184 matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, 185 ConversionPatternRewriter &rewriter) const override { 186 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 187 return failure(); 188 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 189 extractElementOp, extractElementOp.getType(), adaptor.getVector(), 190 extractElementOp.getPosition()); 191 return success(); 192 } 193 }; 194 195 struct VectorInsertElementOpConvert final 196 : public OpConversionPattern<vector::InsertElementOp> { 197 using OpConversionPattern::OpConversionPattern; 198 199 LogicalResult 200 matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, 201 ConversionPatternRewriter &rewriter) const override { 202 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 203 return failure(); 204 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 205 insertElementOp, insertElementOp.getType(), insertElementOp.getDest(), 206 adaptor.getSource(), insertElementOp.getPosition()); 207 return success(); 208 } 209 }; 210 211 struct VectorInsertStridedSliceOpConvert final 212 : public OpConversionPattern<vector::InsertStridedSliceOp> { 213 using OpConversionPattern::OpConversionPattern; 214 215 LogicalResult 216 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, 217 ConversionPatternRewriter &rewriter) const override { 218 Value srcVector = adaptor.getOperands().front(); 219 Value dstVector = adaptor.getOperands().back(); 220 221 uint64_t stride = getFirstIntValue(insertOp.getStrides()); 222 if (stride != 1) 223 return failure(); 224 uint64_t offset = getFirstIntValue(insertOp.getOffsets()); 225 226 if (srcVector.getType().isa<spirv::ScalarType>()) { 227 assert(!dstVector.getType().isa<spirv::ScalarType>()); 228 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 229 insertOp, dstVector.getType(), srcVector, dstVector, 230 rewriter.getI32ArrayAttr(offset)); 231 return success(); 232 } 233 234 uint64_t totalSize = 235 dstVector.getType().cast<VectorType>().getNumElements(); 236 uint64_t insertSize = 237 srcVector.getType().cast<VectorType>().getNumElements(); 238 239 SmallVector<int32_t, 2> indices(totalSize); 240 std::iota(indices.begin(), indices.end(), 0); 241 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 242 totalSize); 243 244 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 245 insertOp, dstVector.getType(), dstVector, srcVector, 246 rewriter.getI32ArrayAttr(indices)); 247 248 return success(); 249 } 250 }; 251 252 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> { 253 public: 254 using OpConversionPattern<vector::SplatOp>::OpConversionPattern; 255 256 LogicalResult 257 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, 258 ConversionPatternRewriter &rewriter) const override { 259 VectorType dstVecType = op.getType(); 260 if (!spirv::CompositeType::isValid(dstVecType)) 261 return failure(); 262 SmallVector<Value, 4> source(dstVecType.getNumElements(), 263 adaptor.getInput()); 264 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType, 265 source); 266 return success(); 267 } 268 }; 269 270 struct VectorShuffleOpConvert final 271 : public OpConversionPattern<vector::ShuffleOp> { 272 using OpConversionPattern::OpConversionPattern; 273 274 LogicalResult 275 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 276 ConversionPatternRewriter &rewriter) const override { 277 auto oldResultType = shuffleOp.getVectorType(); 278 if (!spirv::CompositeType::isValid(oldResultType)) 279 return failure(); 280 auto newResultType = getTypeConverter()->convertType(oldResultType); 281 282 auto oldSourceType = shuffleOp.getV1VectorType(); 283 if (oldSourceType.getNumElements() > 1) { 284 SmallVector<int32_t, 4> components = llvm::to_vector<4>( 285 llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t { 286 return attr.cast<IntegerAttr>().getValue().getZExtValue(); 287 })); 288 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 289 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), 290 rewriter.getI32ArrayAttr(components)); 291 return success(); 292 } 293 294 SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()}; 295 SmallVector<Value, 4> newOperands; 296 newOperands.reserve(oldResultType.getNumElements()); 297 for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) { 298 newOperands.push_back(oldOperands[i.getZExtValue()]); 299 } 300 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 301 shuffleOp, newResultType, newOperands); 302 303 return success(); 304 } 305 }; 306 307 } // namespace 308 309 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 310 RewritePatternSet &patterns) { 311 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 312 VectorExtractElementOpConvert, VectorExtractOpConvert, 313 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 314 VectorInsertElementOpConvert, VectorInsertOpConvert, 315 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 316 VectorSplatPattern>(typeConverter, patterns.getContext()); 317 } 318