1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include <numeric> 23 24 using namespace mlir; 25 26 /// Gets the first integer value from `attr`, assuming it is an integer array 27 /// attribute. 28 static uint64_t getFirstIntValue(ArrayAttr attr) { 29 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 30 } 31 32 namespace { 33 34 struct VectorBitcastConvert final 35 : public OpConversionPattern<vector::BitCastOp> { 36 using OpConversionPattern::OpConversionPattern; 37 38 LogicalResult 39 matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands, 40 ConversionPatternRewriter &rewriter) const override { 41 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 42 if (!dstType) 43 return failure(); 44 45 vector::BitCastOp::Adaptor adaptor(operands); 46 if (dstType == adaptor.source().getType()) 47 rewriter.replaceOp(bitcastOp, adaptor.source()); 48 else 49 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 50 adaptor.source()); 51 52 return success(); 53 } 54 }; 55 56 struct VectorBroadcastConvert final 57 : public OpConversionPattern<vector::BroadcastOp> { 58 using OpConversionPattern::OpConversionPattern; 59 60 LogicalResult 61 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 62 ConversionPatternRewriter &rewriter) const override { 63 if (broadcastOp.source().getType().isa<VectorType>() || 64 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 65 return failure(); 66 vector::BroadcastOp::Adaptor adaptor(operands); 67 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 68 adaptor.source()); 69 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 70 broadcastOp, broadcastOp.getVectorType(), source); 71 return success(); 72 } 73 }; 74 75 struct VectorExtractOpConvert final 76 : public OpConversionPattern<vector::ExtractOp> { 77 using OpConversionPattern::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 81 ConversionPatternRewriter &rewriter) const override { 82 // Only support extracting a scalar value now. 83 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 84 if (resultVectorType && resultVectorType.getNumElements() > 1) 85 return failure(); 86 87 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 88 if (!dstType) 89 return failure(); 90 91 vector::ExtractOp::Adaptor adaptor(operands); 92 if (adaptor.vector().getType().isa<spirv::ScalarType>()) { 93 rewriter.replaceOp(extractOp, adaptor.vector()); 94 return success(); 95 } 96 97 int32_t id = getFirstIntValue(extractOp.position()); 98 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 99 extractOp, adaptor.vector(), id); 100 return success(); 101 } 102 }; 103 104 struct VectorExtractStridedSliceOpConvert final 105 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 106 using OpConversionPattern::OpConversionPattern; 107 108 LogicalResult 109 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 110 ArrayRef<Value> operands, 111 ConversionPatternRewriter &rewriter) const override { 112 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 113 if (!dstType) 114 return failure(); 115 116 // Extract vector<1xT> not supported yet. 117 if (dstType.isa<spirv::ScalarType>()) 118 return failure(); 119 120 uint64_t offset = getFirstIntValue(extractOp.offsets()); 121 uint64_t size = getFirstIntValue(extractOp.sizes()); 122 uint64_t stride = getFirstIntValue(extractOp.strides()); 123 if (stride != 1) 124 return failure(); 125 126 Value srcVector = operands.front(); 127 128 SmallVector<int32_t, 2> indices(size); 129 std::iota(indices.begin(), indices.end(), offset); 130 131 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 132 extractOp, dstType, srcVector, srcVector, 133 rewriter.getI32ArrayAttr(indices)); 134 135 return success(); 136 } 137 }; 138 139 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 140 using OpConversionPattern::OpConversionPattern; 141 142 LogicalResult 143 matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 144 ConversionPatternRewriter &rewriter) const override { 145 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 146 return failure(); 147 vector::FMAOp::Adaptor adaptor(operands); 148 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 149 fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); 150 return success(); 151 } 152 }; 153 154 struct VectorInsertOpConvert final 155 : public OpConversionPattern<vector::InsertOp> { 156 using OpConversionPattern::OpConversionPattern; 157 158 LogicalResult 159 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 160 ConversionPatternRewriter &rewriter) const override { 161 if (insertOp.getSourceType().isa<VectorType>() || 162 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 163 return failure(); 164 vector::InsertOp::Adaptor adaptor(operands); 165 int32_t id = getFirstIntValue(insertOp.position()); 166 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 167 insertOp, adaptor.source(), adaptor.dest(), id); 168 return success(); 169 } 170 }; 171 172 struct VectorExtractElementOpConvert final 173 : public OpConversionPattern<vector::ExtractElementOp> { 174 using OpConversionPattern::OpConversionPattern; 175 176 LogicalResult 177 matchAndRewrite(vector::ExtractElementOp extractElementOp, 178 ArrayRef<Value> operands, 179 ConversionPatternRewriter &rewriter) const override { 180 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 181 return failure(); 182 vector::ExtractElementOp::Adaptor adaptor(operands); 183 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 184 extractElementOp, extractElementOp.getType(), adaptor.vector(), 185 extractElementOp.position()); 186 return success(); 187 } 188 }; 189 190 struct VectorInsertElementOpConvert final 191 : public OpConversionPattern<vector::InsertElementOp> { 192 using OpConversionPattern::OpConversionPattern; 193 194 LogicalResult 195 matchAndRewrite(vector::InsertElementOp insertElementOp, 196 ArrayRef<Value> operands, 197 ConversionPatternRewriter &rewriter) const override { 198 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 199 return failure(); 200 vector::InsertElementOp::Adaptor adaptor(operands); 201 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 202 insertElementOp, insertElementOp.getType(), insertElementOp.dest(), 203 adaptor.source(), insertElementOp.position()); 204 return success(); 205 } 206 }; 207 208 struct VectorInsertStridedSliceOpConvert final 209 : public OpConversionPattern<vector::InsertStridedSliceOp> { 210 using OpConversionPattern::OpConversionPattern; 211 212 LogicalResult 213 matchAndRewrite(vector::InsertStridedSliceOp insertOp, 214 ArrayRef<Value> operands, 215 ConversionPatternRewriter &rewriter) const override { 216 Value srcVector = operands.front(); 217 Value dstVector = operands.back(); 218 219 // Insert scalar values not supported yet. 220 if (srcVector.getType().isa<spirv::ScalarType>() || 221 dstVector.getType().isa<spirv::ScalarType>()) 222 return failure(); 223 224 uint64_t stride = getFirstIntValue(insertOp.strides()); 225 if (stride != 1) 226 return failure(); 227 228 uint64_t totalSize = 229 dstVector.getType().cast<VectorType>().getNumElements(); 230 uint64_t insertSize = 231 srcVector.getType().cast<VectorType>().getNumElements(); 232 uint64_t offset = getFirstIntValue(insertOp.offsets()); 233 234 SmallVector<int32_t, 2> indices(totalSize); 235 std::iota(indices.begin(), indices.end(), 0); 236 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 237 totalSize); 238 239 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 240 insertOp, dstVector.getType(), dstVector, srcVector, 241 rewriter.getI32ArrayAttr(indices)); 242 243 return success(); 244 } 245 }; 246 247 } // namespace 248 249 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 250 RewritePatternSet &patterns) { 251 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 252 VectorExtractElementOpConvert, VectorExtractOpConvert, 253 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 254 VectorInsertElementOpConvert, VectorInsertOpConvert, 255 VectorInsertStridedSliceOpConvert>(typeConverter, 256 patterns.getContext()); 257 } 258