1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/IR/BuiltinAttributes.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "llvm/ADT/ArrayRef.h" 25 #include "llvm/ADT/STLExtras.h" 26 #include <numeric> 27 28 using namespace mlir; 29 30 /// Gets the first integer value from `attr`, assuming it is an integer array 31 /// attribute. 32 static uint64_t getFirstIntValue(ArrayAttr attr) { 33 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 34 } 35 36 namespace { 37 38 struct VectorBitcastConvert final 39 : public OpConversionPattern<vector::BitCastOp> { 40 using OpConversionPattern::OpConversionPattern; 41 42 LogicalResult 43 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, 44 ConversionPatternRewriter &rewriter) const override { 45 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 46 if (!dstType) 47 return failure(); 48 49 if (dstType == adaptor.getSource().getType()) 50 rewriter.replaceOp(bitcastOp, adaptor.getSource()); 51 else 52 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 53 adaptor.getSource()); 54 55 return success(); 56 } 57 }; 58 59 struct VectorBroadcastConvert final 60 : public OpConversionPattern<vector::BroadcastOp> { 61 using OpConversionPattern::OpConversionPattern; 62 63 LogicalResult 64 matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, 65 ConversionPatternRewriter &rewriter) const override { 66 if (broadcastOp.getSource().getType().isa<VectorType>() || 67 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 68 return failure(); 69 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 70 adaptor.getSource()); 71 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 72 broadcastOp, broadcastOp.getVectorType(), source); 73 return success(); 74 } 75 }; 76 77 struct VectorExtractOpConvert final 78 : public OpConversionPattern<vector::ExtractOp> { 79 using OpConversionPattern::OpConversionPattern; 80 81 LogicalResult 82 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 83 ConversionPatternRewriter &rewriter) const override { 84 // Only support extracting a scalar value now. 85 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 86 if (resultVectorType && resultVectorType.getNumElements() > 1) 87 return failure(); 88 89 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 90 if (!dstType) 91 return failure(); 92 93 if (adaptor.getVector().getType().isa<spirv::ScalarType>()) { 94 rewriter.replaceOp(extractOp, adaptor.getVector()); 95 return success(); 96 } 97 98 int32_t id = getFirstIntValue(extractOp.getPosition()); 99 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 100 extractOp, adaptor.getVector(), id); 101 return success(); 102 } 103 }; 104 105 struct VectorExtractStridedSliceOpConvert final 106 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 107 using OpConversionPattern::OpConversionPattern; 108 109 LogicalResult 110 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, 111 ConversionPatternRewriter &rewriter) const override { 112 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 113 if (!dstType) 114 return failure(); 115 116 uint64_t offset = getFirstIntValue(extractOp.getOffsets()); 117 uint64_t size = getFirstIntValue(extractOp.getSizes()); 118 uint64_t stride = getFirstIntValue(extractOp.getStrides()); 119 if (stride != 1) 120 return failure(); 121 122 Value srcVector = adaptor.getOperands().front(); 123 124 // Extract vector<1xT> case. 125 if (dstType.isa<spirv::ScalarType>()) { 126 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, 127 srcVector, offset); 128 return success(); 129 } 130 131 SmallVector<int32_t, 2> indices(size); 132 std::iota(indices.begin(), indices.end(), offset); 133 134 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 135 extractOp, dstType, srcVector, srcVector, 136 rewriter.getI32ArrayAttr(indices)); 137 138 return success(); 139 } 140 }; 141 142 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 143 using OpConversionPattern::OpConversionPattern; 144 145 LogicalResult 146 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 147 ConversionPatternRewriter &rewriter) const override { 148 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 149 return failure(); 150 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 151 fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(), 152 adaptor.getAcc()); 153 return success(); 154 } 155 }; 156 157 struct VectorInsertOpConvert final 158 : public OpConversionPattern<vector::InsertOp> { 159 using OpConversionPattern::OpConversionPattern; 160 161 LogicalResult 162 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 163 ConversionPatternRewriter &rewriter) const override { 164 // Special case for inserting scalar values into size-1 vectors. 165 if (insertOp.getSourceType().isIntOrFloat() && 166 insertOp.getDestVectorType().getNumElements() == 1) { 167 rewriter.replaceOp(insertOp, adaptor.getSource()); 168 return success(); 169 } 170 171 if (insertOp.getSourceType().isa<VectorType>() || 172 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 173 return failure(); 174 int32_t id = getFirstIntValue(insertOp.getPosition()); 175 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 176 insertOp, adaptor.getSource(), adaptor.getDest(), id); 177 return success(); 178 } 179 }; 180 181 struct VectorExtractElementOpConvert final 182 : public OpConversionPattern<vector::ExtractElementOp> { 183 using OpConversionPattern::OpConversionPattern; 184 185 LogicalResult 186 matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, 187 ConversionPatternRewriter &rewriter) const override { 188 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 189 return failure(); 190 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 191 extractElementOp, extractElementOp.getType(), adaptor.getVector(), 192 extractElementOp.getPosition()); 193 return success(); 194 } 195 }; 196 197 struct VectorInsertElementOpConvert final 198 : public OpConversionPattern<vector::InsertElementOp> { 199 using OpConversionPattern::OpConversionPattern; 200 201 LogicalResult 202 matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, 203 ConversionPatternRewriter &rewriter) const override { 204 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 205 return failure(); 206 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 207 insertElementOp, insertElementOp.getType(), insertElementOp.getDest(), 208 adaptor.getSource(), insertElementOp.getPosition()); 209 return success(); 210 } 211 }; 212 213 struct VectorInsertStridedSliceOpConvert final 214 : public OpConversionPattern<vector::InsertStridedSliceOp> { 215 using OpConversionPattern::OpConversionPattern; 216 217 LogicalResult 218 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, 219 ConversionPatternRewriter &rewriter) const override { 220 Value srcVector = adaptor.getOperands().front(); 221 Value dstVector = adaptor.getOperands().back(); 222 223 uint64_t stride = getFirstIntValue(insertOp.getStrides()); 224 if (stride != 1) 225 return failure(); 226 uint64_t offset = getFirstIntValue(insertOp.getOffsets()); 227 228 if (srcVector.getType().isa<spirv::ScalarType>()) { 229 assert(!dstVector.getType().isa<spirv::ScalarType>()); 230 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 231 insertOp, dstVector.getType(), srcVector, dstVector, 232 rewriter.getI32ArrayAttr(offset)); 233 return success(); 234 } 235 236 uint64_t totalSize = 237 dstVector.getType().cast<VectorType>().getNumElements(); 238 uint64_t insertSize = 239 srcVector.getType().cast<VectorType>().getNumElements(); 240 241 SmallVector<int32_t, 2> indices(totalSize); 242 std::iota(indices.begin(), indices.end(), 0); 243 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 244 totalSize); 245 246 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 247 insertOp, dstVector.getType(), dstVector, srcVector, 248 rewriter.getI32ArrayAttr(indices)); 249 250 return success(); 251 } 252 }; 253 254 struct VectorReductionPattern final 255 : public OpConversionPattern<vector::ReductionOp> { 256 using OpConversionPattern::OpConversionPattern; 257 258 LogicalResult 259 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, 260 ConversionPatternRewriter &rewriter) const override { 261 Type resultType = typeConverter->convertType(reduceOp.getType()); 262 if (!resultType) 263 return failure(); 264 265 auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>(); 266 if (!srcVectorType || srcVectorType.getRank() != 1) 267 return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source"); 268 269 // Extract all elements. 270 int numElements = srcVectorType.getDimSize(0); 271 SmallVector<Value, 4> values; 272 values.reserve(numElements + (adaptor.getAcc() != nullptr)); 273 Location loc = reduceOp.getLoc(); 274 for (int i = 0; i < numElements; ++i) { 275 values.push_back(rewriter.create<spirv::CompositeExtractOp>( 276 loc, srcVectorType.getElementType(), adaptor.getVector(), 277 rewriter.getI32ArrayAttr({i}))); 278 } 279 if (Value acc = adaptor.getAcc()) 280 values.push_back(acc); 281 282 // Reduce them. 283 Value result = values.front(); 284 for (Value next : llvm::makeArrayRef(values).drop_front()) { 285 switch (reduceOp.getKind()) { 286 #define INT_FLOAT_CASE(kind, iop, fop) \ 287 case vector::CombiningKind::kind: \ 288 if (resultType.isa<IntegerType>()) { \ 289 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \ 290 } else { \ 291 assert(resultType.isa<FloatType>()); \ 292 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \ 293 } \ 294 break 295 296 INT_FLOAT_CASE(ADD, IAddOp, FAddOp); 297 INT_FLOAT_CASE(MUL, IMulOp, FMulOp); 298 299 case vector::CombiningKind::MINUI: 300 case vector::CombiningKind::MINSI: 301 case vector::CombiningKind::MINF: 302 case vector::CombiningKind::MAXUI: 303 case vector::CombiningKind::MAXSI: 304 case vector::CombiningKind::MAXF: 305 case vector::CombiningKind::AND: 306 case vector::CombiningKind::OR: 307 case vector::CombiningKind::XOR: 308 return rewriter.notifyMatchFailure(reduceOp, "unimplemented"); 309 } 310 } 311 312 rewriter.replaceOp(reduceOp, result); 313 return success(); 314 } 315 }; 316 317 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> { 318 public: 319 using OpConversionPattern<vector::SplatOp>::OpConversionPattern; 320 321 LogicalResult 322 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, 323 ConversionPatternRewriter &rewriter) const override { 324 VectorType dstVecType = op.getType(); 325 if (!spirv::CompositeType::isValid(dstVecType)) 326 return failure(); 327 SmallVector<Value, 4> source(dstVecType.getNumElements(), 328 adaptor.getInput()); 329 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType, 330 source); 331 return success(); 332 } 333 }; 334 335 struct VectorShuffleOpConvert final 336 : public OpConversionPattern<vector::ShuffleOp> { 337 using OpConversionPattern::OpConversionPattern; 338 339 LogicalResult 340 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 341 ConversionPatternRewriter &rewriter) const override { 342 auto oldResultType = shuffleOp.getVectorType(); 343 if (!spirv::CompositeType::isValid(oldResultType)) 344 return failure(); 345 auto newResultType = getTypeConverter()->convertType(oldResultType); 346 347 auto oldSourceType = shuffleOp.getV1VectorType(); 348 if (oldSourceType.getNumElements() > 1) { 349 SmallVector<int32_t, 4> components = llvm::to_vector<4>( 350 llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t { 351 return attr.cast<IntegerAttr>().getValue().getZExtValue(); 352 })); 353 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 354 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), 355 rewriter.getI32ArrayAttr(components)); 356 return success(); 357 } 358 359 SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()}; 360 SmallVector<Value, 4> newOperands; 361 newOperands.reserve(oldResultType.getNumElements()); 362 for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) { 363 newOperands.push_back(oldOperands[i.getZExtValue()]); 364 } 365 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 366 shuffleOp, newResultType, newOperands); 367 368 return success(); 369 } 370 }; 371 372 } // namespace 373 374 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 375 RewritePatternSet &patterns) { 376 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 377 VectorExtractElementOpConvert, VectorExtractOpConvert, 378 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 379 VectorInsertElementOpConvert, VectorInsertOpConvert, 380 VectorReductionPattern, VectorInsertStridedSliceOpConvert, 381 VectorShuffleOpConvert, VectorSplatPattern>( 382 typeConverter, patterns.getContext()); 383 } 384