1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include <numeric> 23 24 using namespace mlir; 25 26 /// Gets the first integer value from `attr`, assuming it is an integer array 27 /// attribute. 28 static uint64_t getFirstIntValue(ArrayAttr attr) { 29 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 30 } 31 32 namespace { 33 34 struct VectorBitcastConvert final 35 : public OpConversionPattern<vector::BitCastOp> { 36 using OpConversionPattern::OpConversionPattern; 37 38 LogicalResult 39 matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands, 40 ConversionPatternRewriter &rewriter) const override { 41 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 42 if (!dstType) 43 return failure(); 44 45 vector::BitCastOp::Adaptor adaptor(operands); 46 if (dstType == adaptor.source().getType()) 47 rewriter.replaceOp(bitcastOp, adaptor.source()); 48 else 49 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 50 adaptor.source()); 51 52 return success(); 53 } 54 }; 55 56 struct VectorBroadcastConvert final 57 : public OpConversionPattern<vector::BroadcastOp> { 58 using OpConversionPattern::OpConversionPattern; 59 60 LogicalResult 61 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 62 ConversionPatternRewriter &rewriter) const override { 63 if (broadcastOp.source().getType().isa<VectorType>() || 64 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 65 return failure(); 66 vector::BroadcastOp::Adaptor adaptor(operands); 67 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 68 adaptor.source()); 69 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 70 broadcastOp, broadcastOp.getVectorType(), source); 71 return success(); 72 } 73 }; 74 75 struct VectorExtractOpConvert final 76 : public OpConversionPattern<vector::ExtractOp> { 77 using OpConversionPattern::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 81 ConversionPatternRewriter &rewriter) const override { 82 // Only support extracting a scalar value now. 83 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 84 if (resultVectorType && resultVectorType.getNumElements() > 1) 85 return failure(); 86 87 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 88 if (!dstType) 89 return failure(); 90 91 vector::ExtractOp::Adaptor adaptor(operands); 92 if (adaptor.vector().getType().isa<spirv::ScalarType>()) { 93 rewriter.replaceOp(extractOp, adaptor.vector()); 94 return success(); 95 } 96 97 int32_t id = getFirstIntValue(extractOp.position()); 98 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 99 extractOp, adaptor.vector(), id); 100 return success(); 101 } 102 }; 103 104 struct VectorExtractStridedSliceOpConvert final 105 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 106 using OpConversionPattern::OpConversionPattern; 107 108 LogicalResult 109 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 110 ArrayRef<Value> operands, 111 ConversionPatternRewriter &rewriter) const override { 112 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 113 if (!dstType) 114 return failure(); 115 116 117 uint64_t offset = getFirstIntValue(extractOp.offsets()); 118 uint64_t size = getFirstIntValue(extractOp.sizes()); 119 uint64_t stride = getFirstIntValue(extractOp.strides()); 120 if (stride != 1) 121 return failure(); 122 123 Value srcVector = operands.front(); 124 125 // Extract vector<1xT> case. 126 if (dstType.isa<spirv::ScalarType>()) { 127 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, 128 srcVector, offset); 129 return success(); 130 } 131 132 SmallVector<int32_t, 2> indices(size); 133 std::iota(indices.begin(), indices.end(), offset); 134 135 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 136 extractOp, dstType, srcVector, srcVector, 137 rewriter.getI32ArrayAttr(indices)); 138 139 return success(); 140 } 141 }; 142 143 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 144 using OpConversionPattern::OpConversionPattern; 145 146 LogicalResult 147 matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 148 ConversionPatternRewriter &rewriter) const override { 149 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 150 return failure(); 151 vector::FMAOp::Adaptor adaptor(operands); 152 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 153 fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); 154 return success(); 155 } 156 }; 157 158 struct VectorInsertOpConvert final 159 : public OpConversionPattern<vector::InsertOp> { 160 using OpConversionPattern::OpConversionPattern; 161 162 LogicalResult 163 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 164 ConversionPatternRewriter &rewriter) const override { 165 if (insertOp.getSourceType().isa<VectorType>() || 166 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 167 return failure(); 168 vector::InsertOp::Adaptor adaptor(operands); 169 int32_t id = getFirstIntValue(insertOp.position()); 170 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 171 insertOp, adaptor.source(), adaptor.dest(), id); 172 return success(); 173 } 174 }; 175 176 struct VectorExtractElementOpConvert final 177 : public OpConversionPattern<vector::ExtractElementOp> { 178 using OpConversionPattern::OpConversionPattern; 179 180 LogicalResult 181 matchAndRewrite(vector::ExtractElementOp extractElementOp, 182 ArrayRef<Value> operands, 183 ConversionPatternRewriter &rewriter) const override { 184 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 185 return failure(); 186 vector::ExtractElementOp::Adaptor adaptor(operands); 187 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 188 extractElementOp, extractElementOp.getType(), adaptor.vector(), 189 extractElementOp.position()); 190 return success(); 191 } 192 }; 193 194 struct VectorInsertElementOpConvert final 195 : public OpConversionPattern<vector::InsertElementOp> { 196 using OpConversionPattern::OpConversionPattern; 197 198 LogicalResult 199 matchAndRewrite(vector::InsertElementOp insertElementOp, 200 ArrayRef<Value> operands, 201 ConversionPatternRewriter &rewriter) const override { 202 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 203 return failure(); 204 vector::InsertElementOp::Adaptor adaptor(operands); 205 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 206 insertElementOp, insertElementOp.getType(), insertElementOp.dest(), 207 adaptor.source(), insertElementOp.position()); 208 return success(); 209 } 210 }; 211 212 struct VectorInsertStridedSliceOpConvert final 213 : public OpConversionPattern<vector::InsertStridedSliceOp> { 214 using OpConversionPattern::OpConversionPattern; 215 216 LogicalResult 217 matchAndRewrite(vector::InsertStridedSliceOp insertOp, 218 ArrayRef<Value> operands, 219 ConversionPatternRewriter &rewriter) const override { 220 Value srcVector = operands.front(); 221 Value dstVector = operands.back(); 222 223 // Insert scalar values not supported yet. 224 if (srcVector.getType().isa<spirv::ScalarType>() || 225 dstVector.getType().isa<spirv::ScalarType>()) 226 return failure(); 227 228 uint64_t stride = getFirstIntValue(insertOp.strides()); 229 if (stride != 1) 230 return failure(); 231 232 uint64_t totalSize = 233 dstVector.getType().cast<VectorType>().getNumElements(); 234 uint64_t insertSize = 235 srcVector.getType().cast<VectorType>().getNumElements(); 236 uint64_t offset = getFirstIntValue(insertOp.offsets()); 237 238 SmallVector<int32_t, 2> indices(totalSize); 239 std::iota(indices.begin(), indices.end(), 0); 240 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 241 totalSize); 242 243 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 244 insertOp, dstVector.getType(), dstVector, srcVector, 245 rewriter.getI32ArrayAttr(indices)); 246 247 return success(); 248 } 249 }; 250 251 } // namespace 252 253 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 254 RewritePatternSet &patterns) { 255 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 256 VectorExtractElementOpConvert, VectorExtractOpConvert, 257 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 258 VectorInsertElementOpConvert, VectorInsertOpConvert, 259 VectorInsertStridedSliceOpConvert>(typeConverter, 260 patterns.getContext()); 261 } 262