1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/IR/BuiltinAttributes.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include <numeric> 25 26 using namespace mlir; 27 28 /// Gets the first integer value from `attr`, assuming it is an integer array 29 /// attribute. 30 static uint64_t getFirstIntValue(ArrayAttr attr) { 31 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 32 } 33 34 namespace { 35 36 struct VectorBitcastConvert final 37 : public OpConversionPattern<vector::BitCastOp> { 38 using OpConversionPattern::OpConversionPattern; 39 40 LogicalResult 41 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, 42 ConversionPatternRewriter &rewriter) const override { 43 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 44 if (!dstType) 45 return failure(); 46 47 if (dstType == adaptor.source().getType()) 48 rewriter.replaceOp(bitcastOp, adaptor.source()); 49 else 50 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 51 adaptor.source()); 52 53 return success(); 54 } 55 }; 56 57 struct VectorBroadcastConvert final 58 : public OpConversionPattern<vector::BroadcastOp> { 59 using OpConversionPattern::OpConversionPattern; 60 61 LogicalResult 62 matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, 63 ConversionPatternRewriter &rewriter) const override { 64 if (broadcastOp.source().getType().isa<VectorType>() || 65 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 66 return failure(); 67 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 68 adaptor.source()); 69 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 70 broadcastOp, broadcastOp.getVectorType(), source); 71 return success(); 72 } 73 }; 74 75 struct VectorExtractOpConvert final 76 : public OpConversionPattern<vector::ExtractOp> { 77 using OpConversionPattern::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 81 ConversionPatternRewriter &rewriter) const override { 82 // Only support extracting a scalar value now. 83 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 84 if (resultVectorType && resultVectorType.getNumElements() > 1) 85 return failure(); 86 87 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 88 if (!dstType) 89 return failure(); 90 91 if (adaptor.vector().getType().isa<spirv::ScalarType>()) { 92 rewriter.replaceOp(extractOp, adaptor.vector()); 93 return success(); 94 } 95 96 int32_t id = getFirstIntValue(extractOp.position()); 97 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 98 extractOp, adaptor.vector(), id); 99 return success(); 100 } 101 }; 102 103 struct VectorExtractStridedSliceOpConvert final 104 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 105 using OpConversionPattern::OpConversionPattern; 106 107 LogicalResult 108 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, 109 ConversionPatternRewriter &rewriter) const override { 110 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 111 if (!dstType) 112 return failure(); 113 114 115 uint64_t offset = getFirstIntValue(extractOp.offsets()); 116 uint64_t size = getFirstIntValue(extractOp.sizes()); 117 uint64_t stride = getFirstIntValue(extractOp.strides()); 118 if (stride != 1) 119 return failure(); 120 121 Value srcVector = adaptor.getOperands().front(); 122 123 // Extract vector<1xT> case. 124 if (dstType.isa<spirv::ScalarType>()) { 125 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, 126 srcVector, offset); 127 return success(); 128 } 129 130 SmallVector<int32_t, 2> indices(size); 131 std::iota(indices.begin(), indices.end(), offset); 132 133 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 134 extractOp, dstType, srcVector, srcVector, 135 rewriter.getI32ArrayAttr(indices)); 136 137 return success(); 138 } 139 }; 140 141 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 142 using OpConversionPattern::OpConversionPattern; 143 144 LogicalResult 145 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 146 ConversionPatternRewriter &rewriter) const override { 147 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 148 return failure(); 149 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 150 fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); 151 return success(); 152 } 153 }; 154 155 struct VectorInsertOpConvert final 156 : public OpConversionPattern<vector::InsertOp> { 157 using OpConversionPattern::OpConversionPattern; 158 159 LogicalResult 160 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 161 ConversionPatternRewriter &rewriter) const override { 162 // Special case for inserting scalar values into size-1 vectors. 163 if (insertOp.getSourceType().isIntOrFloat() && 164 insertOp.getDestVectorType().getNumElements() == 1) { 165 rewriter.replaceOp(insertOp, adaptor.source()); 166 return success(); 167 } 168 169 if (insertOp.getSourceType().isa<VectorType>() || 170 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 171 return failure(); 172 int32_t id = getFirstIntValue(insertOp.position()); 173 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 174 insertOp, adaptor.source(), adaptor.dest(), id); 175 return success(); 176 } 177 }; 178 179 struct VectorExtractElementOpConvert final 180 : public OpConversionPattern<vector::ExtractElementOp> { 181 using OpConversionPattern::OpConversionPattern; 182 183 LogicalResult 184 matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, 185 ConversionPatternRewriter &rewriter) const override { 186 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 187 return failure(); 188 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 189 extractElementOp, extractElementOp.getType(), adaptor.vector(), 190 extractElementOp.position()); 191 return success(); 192 } 193 }; 194 195 struct VectorInsertElementOpConvert final 196 : public OpConversionPattern<vector::InsertElementOp> { 197 using OpConversionPattern::OpConversionPattern; 198 199 LogicalResult 200 matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, 201 ConversionPatternRewriter &rewriter) const override { 202 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 203 return failure(); 204 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 205 insertElementOp, insertElementOp.getType(), insertElementOp.dest(), 206 adaptor.source(), insertElementOp.position()); 207 return success(); 208 } 209 }; 210 211 struct VectorInsertStridedSliceOpConvert final 212 : public OpConversionPattern<vector::InsertStridedSliceOp> { 213 using OpConversionPattern::OpConversionPattern; 214 215 LogicalResult 216 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, 217 ConversionPatternRewriter &rewriter) const override { 218 Value srcVector = adaptor.getOperands().front(); 219 Value dstVector = adaptor.getOperands().back(); 220 221 uint64_t stride = getFirstIntValue(insertOp.strides()); 222 if (stride != 1) 223 return failure(); 224 uint64_t offset = getFirstIntValue(insertOp.offsets()); 225 226 if (srcVector.getType().isa<spirv::ScalarType>()) { 227 assert(!dstVector.getType().isa<spirv::ScalarType>()); 228 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 229 insertOp, dstVector.getType(), srcVector, dstVector, 230 rewriter.getI32ArrayAttr(offset)); 231 return success(); 232 } 233 234 uint64_t totalSize = 235 dstVector.getType().cast<VectorType>().getNumElements(); 236 uint64_t insertSize = 237 srcVector.getType().cast<VectorType>().getNumElements(); 238 239 SmallVector<int32_t, 2> indices(totalSize); 240 std::iota(indices.begin(), indices.end(), 0); 241 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 242 totalSize); 243 244 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 245 insertOp, dstVector.getType(), dstVector, srcVector, 246 rewriter.getI32ArrayAttr(indices)); 247 248 return success(); 249 } 250 }; 251 252 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> { 253 public: 254 using OpConversionPattern<vector::SplatOp>::OpConversionPattern; 255 256 LogicalResult 257 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, 258 ConversionPatternRewriter &rewriter) const override { 259 VectorType dstVecType = op.getType(); 260 if (!spirv::CompositeType::isValid(dstVecType)) 261 return failure(); 262 SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input()); 263 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType, 264 source); 265 return success(); 266 } 267 }; 268 269 struct VectorShuffleOpConvert final 270 : public OpConversionPattern<vector::ShuffleOp> { 271 using OpConversionPattern::OpConversionPattern; 272 273 LogicalResult 274 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 275 ConversionPatternRewriter &rewriter) const override { 276 auto oldResultType = shuffleOp.getVectorType(); 277 if (!spirv::CompositeType::isValid(oldResultType)) 278 return failure(); 279 auto newResultType = getTypeConverter()->convertType(oldResultType); 280 281 auto oldSourceType = shuffleOp.getV1VectorType(); 282 if (oldSourceType.getNumElements() > 1) { 283 SmallVector<int32_t, 4> components = llvm::to_vector<4>( 284 llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t { 285 return attr.cast<IntegerAttr>().getValue().getZExtValue(); 286 })); 287 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 288 shuffleOp, newResultType, adaptor.v1(), adaptor.v2(), 289 rewriter.getI32ArrayAttr(components)); 290 return success(); 291 } 292 293 SmallVector<Value, 2> oldOperands = {adaptor.v1(), adaptor.v2()}; 294 SmallVector<Value, 4> newOperands; 295 newOperands.reserve(oldResultType.getNumElements()); 296 for (const APInt &i : shuffleOp.mask().getAsValueRange<IntegerAttr>()) { 297 newOperands.push_back(oldOperands[i.getZExtValue()]); 298 } 299 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 300 shuffleOp, newResultType, newOperands); 301 302 return success(); 303 } 304 }; 305 306 } // namespace 307 308 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 309 RewritePatternSet &patterns) { 310 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 311 VectorExtractElementOpConvert, VectorExtractOpConvert, 312 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 313 VectorInsertElementOpConvert, VectorInsertOpConvert, 314 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 315 VectorSplatPattern>(typeConverter, patterns.getContext()); 316 } 317