//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Vector dialect to SPIRV dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <numeric>

using namespace mlir;

/// Gets the first integer value from `attr`, assuming it is an integer array
/// attribute.
28 static uint64_t getFirstIntValue(ArrayAttr attr) { 29 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 30 } 31 32 namespace { 33 34 struct VectorBitcastConvert final 35 : public OpConversionPattern<vector::BitCastOp> { 36 using OpConversionPattern::OpConversionPattern; 37 38 LogicalResult 39 matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands, 40 ConversionPatternRewriter &rewriter) const override { 41 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 42 if (!dstType) 43 return failure(); 44 45 vector::BitCastOp::Adaptor adaptor(operands); 46 if (dstType == adaptor.source().getType()) 47 rewriter.replaceOp(bitcastOp, adaptor.source()); 48 else 49 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 50 adaptor.source()); 51 52 return success(); 53 } 54 }; 55 56 struct VectorBroadcastConvert final 57 : public OpConversionPattern<vector::BroadcastOp> { 58 using OpConversionPattern::OpConversionPattern; 59 60 LogicalResult 61 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 62 ConversionPatternRewriter &rewriter) const override { 63 if (broadcastOp.source().getType().isa<VectorType>() || 64 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 65 return failure(); 66 vector::BroadcastOp::Adaptor adaptor(operands); 67 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 68 adaptor.source()); 69 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 70 broadcastOp, broadcastOp.getVectorType(), source); 71 return success(); 72 } 73 }; 74 75 struct VectorExtractOpConvert final 76 : public OpConversionPattern<vector::ExtractOp> { 77 using OpConversionPattern::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 81 ConversionPatternRewriter &rewriter) const override { 82 // Only support extracting a scalar value now. 
83 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 84 if (resultVectorType && resultVectorType.getNumElements() > 1) 85 return failure(); 86 87 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 88 if (!dstType) 89 return failure(); 90 91 vector::ExtractOp::Adaptor adaptor(operands); 92 int32_t id = getFirstIntValue(extractOp.position()); 93 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 94 extractOp, adaptor.vector(), id); 95 return success(); 96 } 97 }; 98 99 struct VectorExtractStridedSliceOpConvert final 100 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 101 using OpConversionPattern::OpConversionPattern; 102 103 LogicalResult 104 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 105 ArrayRef<Value> operands, 106 ConversionPatternRewriter &rewriter) const override { 107 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 108 if (!dstType) 109 return failure(); 110 111 // Extract vector<1xT> not supported yet. 
112 if (dstType.isa<spirv::ScalarType>()) 113 return failure(); 114 115 uint64_t offset = getFirstIntValue(extractOp.offsets()); 116 uint64_t size = getFirstIntValue(extractOp.sizes()); 117 uint64_t stride = getFirstIntValue(extractOp.strides()); 118 if (stride != 1) 119 return failure(); 120 121 Value srcVector = operands.front(); 122 123 SmallVector<int32_t, 2> indices(size); 124 std::iota(indices.begin(), indices.end(), offset); 125 126 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 127 extractOp, dstType, srcVector, srcVector, 128 rewriter.getI32ArrayAttr(indices)); 129 130 return success(); 131 } 132 }; 133 134 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 135 using OpConversionPattern::OpConversionPattern; 136 137 LogicalResult 138 matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 139 ConversionPatternRewriter &rewriter) const override { 140 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 141 return failure(); 142 vector::FMAOp::Adaptor adaptor(operands); 143 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 144 fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); 145 return success(); 146 } 147 }; 148 149 struct VectorInsertOpConvert final 150 : public OpConversionPattern<vector::InsertOp> { 151 using OpConversionPattern::OpConversionPattern; 152 153 LogicalResult 154 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 155 ConversionPatternRewriter &rewriter) const override { 156 if (insertOp.getSourceType().isa<VectorType>() || 157 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 158 return failure(); 159 vector::InsertOp::Adaptor adaptor(operands); 160 int32_t id = getFirstIntValue(insertOp.position()); 161 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 162 insertOp, adaptor.source(), adaptor.dest(), id); 163 return success(); 164 } 165 }; 166 167 struct VectorExtractElementOpConvert final 168 : public 
OpConversionPattern<vector::ExtractElementOp> { 169 using OpConversionPattern::OpConversionPattern; 170 171 LogicalResult 172 matchAndRewrite(vector::ExtractElementOp extractElementOp, 173 ArrayRef<Value> operands, 174 ConversionPatternRewriter &rewriter) const override { 175 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 176 return failure(); 177 vector::ExtractElementOp::Adaptor adaptor(operands); 178 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 179 extractElementOp, extractElementOp.getType(), adaptor.vector(), 180 extractElementOp.position()); 181 return success(); 182 } 183 }; 184 185 struct VectorInsertElementOpConvert final 186 : public OpConversionPattern<vector::InsertElementOp> { 187 using OpConversionPattern::OpConversionPattern; 188 189 LogicalResult 190 matchAndRewrite(vector::InsertElementOp insertElementOp, 191 ArrayRef<Value> operands, 192 ConversionPatternRewriter &rewriter) const override { 193 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 194 return failure(); 195 vector::InsertElementOp::Adaptor adaptor(operands); 196 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 197 insertElementOp, insertElementOp.getType(), insertElementOp.dest(), 198 adaptor.source(), insertElementOp.position()); 199 return success(); 200 } 201 }; 202 203 struct VectorInsertStridedSliceOpConvert final 204 : public OpConversionPattern<vector::InsertStridedSliceOp> { 205 using OpConversionPattern::OpConversionPattern; 206 207 LogicalResult 208 matchAndRewrite(vector::InsertStridedSliceOp insertOp, 209 ArrayRef<Value> operands, 210 ConversionPatternRewriter &rewriter) const override { 211 Value srcVector = operands.front(); 212 Value dstVector = operands.back(); 213 214 // Insert scalar values not supported yet. 
215 if (srcVector.getType().isa<spirv::ScalarType>() || 216 dstVector.getType().isa<spirv::ScalarType>()) 217 return failure(); 218 219 uint64_t stride = getFirstIntValue(insertOp.strides()); 220 if (stride != 1) 221 return failure(); 222 223 uint64_t totalSize = 224 dstVector.getType().cast<VectorType>().getNumElements(); 225 uint64_t insertSize = 226 srcVector.getType().cast<VectorType>().getNumElements(); 227 uint64_t offset = getFirstIntValue(insertOp.offsets()); 228 229 SmallVector<int32_t, 2> indices(totalSize); 230 std::iota(indices.begin(), indices.end(), 0); 231 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 232 totalSize); 233 234 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 235 insertOp, dstVector.getType(), dstVector, srcVector, 236 rewriter.getI32ArrayAttr(indices)); 237 238 return success(); 239 } 240 }; 241 242 } // namespace 243 244 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 245 RewritePatternSet &patterns) { 246 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 247 VectorExtractElementOpConvert, VectorExtractOpConvert, 248 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 249 VectorInsertElementOpConvert, VectorInsertOpConvert, 250 VectorInsertStridedSliceOpConvert>(typeConverter, 251 patterns.getContext()); 252 } 253