//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
// E.g. vector<2x4x8xT> -> vector<4x8xT>. Only valid for rank > 1.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
// E.g. vector<2x4x8xT> -> vector<8xT>. Only valid for rank > 1.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    // Rank-1 vectors lower to LLVM vectors: use insertelement with a
    // constant index operand.
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  // Higher ranks lower to LLVM arrays: use insertvalue with a position
  // attribute instead of a runtime index.
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    // Rank-1 vectors lower to LLVM vectors: use extractelement with a
    // constant index operand.
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  // Higher ranks lower to LLVM arrays: use extractvalue with a position
  // attribute.
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that returns data layout alignment of a memref.
// On success, `align` holds the preferred alignment of the converted element
// type. Fails if the element type cannot be converted to an LLVM type.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Return the minimal alignment value that satisfies all the AssumeAlignment
// uses of `value`. If no such uses exist, return 1.
static unsigned getAssumedAlignment(Value value) {
  unsigned align = 1;
  for (auto &u : value.getUses()) {
    Operation *owner = u.getOwner();
    // lcm combines multiple assume_alignment constraints so the result
    // satisfies each of them.
    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
      align = mlir::lcm(align, op.alignment());
  }
  return align;
}

// Helper that returns data layout alignment of a memref associated with a
// load, store, scatter, or gather op, including additional information from
// assume_alignment calls on the source of the transfer.
template <class OpAdaptor>
LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
                                   OpAdaptor op, unsigned &align) {
  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
    return failure();
  // Take the stronger of the data-layout alignment and any user-asserted
  // alignment on the base memref.
  align = std::max(align, getAssumedAlignment(op.base()));
  return success();
}

// Add an index vector component to a base pointer. This almost always succeeds
// unless the last stride is non-unit or the memory space is not zero.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  // Bail on non-strided layouts, non-unit innermost stride, or non-default
  // memory space; the vector-of-pointers GEP below assumes all three.
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  // Result is a fixed vector of element pointers, one lane per index.
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // 1-1 mapping: forward the flattened operands and the row/column shape
    // attributes straight to the LLVM matrix intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
198 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 199 vector::LoadOpAdaptor adaptor, 200 VectorType vectorTy, Value ptr, unsigned align, 201 ConversionPatternRewriter &rewriter) { 202 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align); 203 } 204 205 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 206 vector::MaskedLoadOpAdaptor adaptor, 207 VectorType vectorTy, Value ptr, unsigned align, 208 ConversionPatternRewriter &rewriter) { 209 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 210 loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align); 211 } 212 213 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 214 vector::StoreOpAdaptor adaptor, 215 VectorType vectorTy, Value ptr, unsigned align, 216 ConversionPatternRewriter &rewriter) { 217 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(), 218 ptr, align); 219 } 220 221 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 222 vector::MaskedStoreOpAdaptor adaptor, 223 VectorType vectorTy, Value ptr, unsigned align, 224 ConversionPatternRewriter &rewriter) { 225 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 226 storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align); 227 } 228 229 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 230 /// vector.maskedstore. 231 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 232 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 233 public: 234 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 235 236 LogicalResult 237 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 238 typename LoadOrStoreOp::Adaptor adaptor, 239 ConversionPatternRewriter &rewriter) const override { 240 // Only 1-D vectors can be lowered to LLVM. 
241 VectorType vectorTy = loadOrStoreOp.getVectorType(); 242 if (vectorTy.getRank() > 1) 243 return failure(); 244 245 auto loc = loadOrStoreOp->getLoc(); 246 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 247 248 // Resolve alignment. 249 unsigned align; 250 if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp, 251 align))) 252 return failure(); 253 254 // Resolve address. 255 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) 256 .template cast<VectorType>(); 257 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(), 258 adaptor.indices(), rewriter); 259 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); 260 261 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 262 return success(); 263 } 264 }; 265 266 /// Conversion pattern for a vector.gather. 267 class VectorGatherOpConversion 268 : public ConvertOpToLLVMPattern<vector::GatherOp> { 269 public: 270 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 271 272 LogicalResult 273 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 274 ConversionPatternRewriter &rewriter) const override { 275 auto loc = gather->getLoc(); 276 MemRefType memRefType = gather.getMemRefType(); 277 278 // Resolve alignment. 279 unsigned align; 280 if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align))) 281 return failure(); 282 283 // Resolve address. 284 Value ptrs; 285 VectorType vType = gather.getVectorType(); 286 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), 287 adaptor.indices(), rewriter); 288 if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, 289 adaptor.index_vec(), memRefType, vType, ptrs))) 290 return failure(); 291 292 // Replace with the gather intrinsic. 
293 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 294 gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), 295 adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 296 return success(); 297 } 298 }; 299 300 /// Conversion pattern for a vector.scatter. 301 class VectorScatterOpConversion 302 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 303 public: 304 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 305 306 LogicalResult 307 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 308 ConversionPatternRewriter &rewriter) const override { 309 auto loc = scatter->getLoc(); 310 MemRefType memRefType = scatter.getMemRefType(); 311 312 // Resolve alignment. 313 unsigned align; 314 if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align))) 315 return failure(); 316 317 // Resolve address. 318 Value ptrs; 319 VectorType vType = scatter.getVectorType(); 320 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), 321 adaptor.indices(), rewriter); 322 if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, 323 adaptor.index_vec(), memRefType, vType, ptrs))) 324 return failure(); 325 326 // Replace with the scatter intrinsic. 327 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 328 scatter, adaptor.valueToStore(), ptrs, adaptor.mask(), 329 rewriter.getI32IntegerAttr(align)); 330 return success(); 331 } 332 }; 333 334 /// Conversion pattern for a vector.expandload. 335 class VectorExpandLoadOpConversion 336 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 337 public: 338 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 339 340 LogicalResult 341 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 342 ConversionPatternRewriter &rewriter) const override { 343 auto loc = expand->getLoc(); 344 MemRefType memRefType = expand.getMemRefType(); 345 346 // Resolve address. 
347 auto vtype = typeConverter->convertType(expand.getVectorType()); 348 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), 349 adaptor.indices(), rewriter); 350 351 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 352 expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru()); 353 return success(); 354 } 355 }; 356 357 /// Conversion pattern for a vector.compressstore. 358 class VectorCompressStoreOpConversion 359 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 360 public: 361 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 362 363 LogicalResult 364 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 365 ConversionPatternRewriter &rewriter) const override { 366 auto loc = compress->getLoc(); 367 MemRefType memRefType = compress.getMemRefType(); 368 369 // Resolve address. 370 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), 371 adaptor.indices(), rewriter); 372 373 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 374 compress, adaptor.valueToStore(), ptr, adaptor.mask()); 375 return success(); 376 } 377 }; 378 379 /// Conversion pattern for all vector reductions. 380 class VectorReductionOpConversion 381 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 382 public: 383 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 384 bool reassociateFPRed) 385 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 386 reassociateFPReductions(reassociateFPRed) {} 387 388 LogicalResult 389 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 390 ConversionPatternRewriter &rewriter) const override { 391 auto kind = reductionOp.kind(); 392 Type eltType = reductionOp.dest().getType(); 393 Type llvmType = typeConverter->convertType(eltType); 394 Value operand = adaptor.getOperands()[0]; 395 if (eltType.isIntOrIndex()) { 396 // Integer reductions: add/mul/min/max/and/or/xor. 
397 if (kind == "add") 398 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp, 399 llvmType, operand); 400 else if (kind == "mul") 401 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp, 402 llvmType, operand); 403 else if (kind == "minui") 404 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 405 reductionOp, llvmType, operand); 406 else if (kind == "minsi") 407 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 408 reductionOp, llvmType, operand); 409 else if (kind == "maxui") 410 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 411 reductionOp, llvmType, operand); 412 else if (kind == "maxsi") 413 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 414 reductionOp, llvmType, operand); 415 else if (kind == "and") 416 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp, 417 llvmType, operand); 418 else if (kind == "or") 419 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp, 420 llvmType, operand); 421 else if (kind == "xor") 422 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp, 423 llvmType, operand); 424 else 425 return failure(); 426 return success(); 427 } 428 429 if (!eltType.isa<FloatType>()) 430 return failure(); 431 432 // Floating-point reductions: add/mul/min/max 433 if (kind == "add") { 434 // Optional accumulator (or zero). 435 Value acc = adaptor.getOperands().size() > 1 436 ? adaptor.getOperands()[1] 437 : rewriter.create<LLVM::ConstantOp>( 438 reductionOp->getLoc(), llvmType, 439 rewriter.getZeroAttr(eltType)); 440 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 441 reductionOp, llvmType, acc, operand, 442 rewriter.getBoolAttr(reassociateFPReductions)); 443 } else if (kind == "mul") { 444 // Optional accumulator (or one). 445 Value acc = adaptor.getOperands().size() > 1 446 ? 
adaptor.getOperands()[1] 447 : rewriter.create<LLVM::ConstantOp>( 448 reductionOp->getLoc(), llvmType, 449 rewriter.getFloatAttr(eltType, 1.0)); 450 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 451 reductionOp, llvmType, acc, operand, 452 rewriter.getBoolAttr(reassociateFPReductions)); 453 } else if (kind == "minf") 454 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 455 // NaNs/-0.0/+0.0 in the same way. 456 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp, 457 llvmType, operand); 458 else if (kind == "maxf") 459 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle 460 // NaNs/-0.0/+0.0 in the same way. 461 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp, 462 llvmType, operand); 463 else 464 return failure(); 465 return success(); 466 } 467 468 private: 469 const bool reassociateFPReductions; 470 }; 471 472 class VectorShuffleOpConversion 473 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 474 public: 475 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 476 477 LogicalResult 478 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 479 ConversionPatternRewriter &rewriter) const override { 480 auto loc = shuffleOp->getLoc(); 481 auto v1Type = shuffleOp.getV1VectorType(); 482 auto v2Type = shuffleOp.getV2VectorType(); 483 auto vectorType = shuffleOp.getVectorType(); 484 Type llvmType = typeConverter->convertType(vectorType); 485 auto maskArrayAttr = shuffleOp.mask(); 486 487 // Bail if result type cannot be lowered. 488 if (!llvmType) 489 return failure(); 490 491 // Get rank and dimension sizes. 492 int64_t rank = vectorType.getRank(); 493 assert(v1Type.getRank() == rank); 494 assert(v2Type.getRank() == rank); 495 int64_t v1Dim = v1Type.getDimSize(0); 496 497 // For rank 1, where both operands have *exactly* the same vector type, 498 // there is direct shuffle support in LLVM. Use it! 
499 if (rank == 1 && v1Type == v2Type) { 500 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 501 loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 502 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 503 return success(); 504 } 505 506 // For all other cases, insert the individual values individually. 507 Type eltType; 508 if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>()) 509 eltType = arrayType.getElementType(); 510 else 511 eltType = llvmType.cast<VectorType>().getElementType(); 512 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 513 int64_t insPos = 0; 514 for (auto en : llvm::enumerate(maskArrayAttr)) { 515 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 516 Value value = adaptor.v1(); 517 if (extPos >= v1Dim) { 518 extPos -= v1Dim; 519 value = adaptor.v2(); 520 } 521 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 522 eltType, rank, extPos); 523 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 524 llvmType, rank, insPos++); 525 } 526 rewriter.replaceOp(shuffleOp, insert); 527 return success(); 528 } 529 }; 530 531 class VectorExtractElementOpConversion 532 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 533 public: 534 using ConvertOpToLLVMPattern< 535 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 536 537 LogicalResult 538 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 539 ConversionPatternRewriter &rewriter) const override { 540 auto vectorType = extractEltOp.getVectorType(); 541 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 542 543 // Bail if result type cannot be lowered. 
544 if (!llvmType) 545 return failure(); 546 547 if (vectorType.getRank() == 0) { 548 Location loc = extractEltOp.getLoc(); 549 auto idxType = rewriter.getIndexType(); 550 auto zero = rewriter.create<LLVM::ConstantOp>( 551 loc, typeConverter->convertType(idxType), 552 rewriter.getIntegerAttr(idxType, 0)); 553 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 554 extractEltOp, llvmType, adaptor.vector(), zero); 555 return success(); 556 } 557 558 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 559 extractEltOp, llvmType, adaptor.vector(), adaptor.position()); 560 return success(); 561 } 562 }; 563 564 class VectorExtractOpConversion 565 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 566 public: 567 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 568 569 LogicalResult 570 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 571 ConversionPatternRewriter &rewriter) const override { 572 auto loc = extractOp->getLoc(); 573 auto vectorType = extractOp.getVectorType(); 574 auto resultType = extractOp.getResult().getType(); 575 auto llvmResultType = typeConverter->convertType(resultType); 576 auto positionArrayAttr = extractOp.position(); 577 578 // Bail if result type cannot be lowered. 579 if (!llvmResultType) 580 return failure(); 581 582 // Extract entire vector. Should be handled by folder, but just to be safe. 583 if (positionArrayAttr.empty()) { 584 rewriter.replaceOp(extractOp, adaptor.vector()); 585 return success(); 586 } 587 588 // One-shot extraction of vector from array (only requires extractvalue). 589 if (resultType.isa<VectorType>()) { 590 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 591 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 592 rewriter.replaceOp(extractOp, extracted); 593 return success(); 594 } 595 596 // Potential extraction of 1-D vector from array. 
597 auto *context = extractOp->getContext(); 598 Value extracted = adaptor.vector(); 599 auto positionAttrs = positionArrayAttr.getValue(); 600 if (positionAttrs.size() > 1) { 601 auto oneDVectorType = reducedVectorTypeBack(vectorType); 602 auto nMinusOnePositionAttrs = 603 ArrayAttr::get(context, positionAttrs.drop_back()); 604 extracted = rewriter.create<LLVM::ExtractValueOp>( 605 loc, typeConverter->convertType(oneDVectorType), extracted, 606 nMinusOnePositionAttrs); 607 } 608 609 // Remaining extraction of element from 1-D LLVM vector 610 auto position = positionAttrs.back().cast<IntegerAttr>(); 611 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 612 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 613 extracted = 614 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 615 rewriter.replaceOp(extractOp, extracted); 616 617 return success(); 618 } 619 }; 620 621 /// Conversion pattern that turns a vector.fma on a 1-D vector 622 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 623 /// This does not match vectors of n >= 2 rank. 
624 /// 625 /// Example: 626 /// ``` 627 /// vector.fma %a, %a, %a : vector<8xf32> 628 /// ``` 629 /// is converted to: 630 /// ``` 631 /// llvm.intr.fmuladd %va, %va, %va: 632 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 633 /// -> !llvm."<8 x f32>"> 634 /// ``` 635 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 636 public: 637 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 638 639 LogicalResult 640 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 641 ConversionPatternRewriter &rewriter) const override { 642 VectorType vType = fmaOp.getVectorType(); 643 if (vType.getRank() != 1) 644 return failure(); 645 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), 646 adaptor.rhs(), adaptor.acc()); 647 return success(); 648 } 649 }; 650 651 class VectorInsertElementOpConversion 652 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 653 public: 654 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 655 656 LogicalResult 657 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 658 ConversionPatternRewriter &rewriter) const override { 659 auto vectorType = insertEltOp.getDestVectorType(); 660 auto llvmType = typeConverter->convertType(vectorType); 661 662 // Bail if result type cannot be lowered. 
663 if (!llvmType) 664 return failure(); 665 666 if (vectorType.getRank() == 0) { 667 Location loc = insertEltOp.getLoc(); 668 auto idxType = rewriter.getIndexType(); 669 auto zero = rewriter.create<LLVM::ConstantOp>( 670 loc, typeConverter->convertType(idxType), 671 rewriter.getIntegerAttr(idxType, 0)); 672 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 673 insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero); 674 return success(); 675 } 676 677 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 678 insertEltOp, llvmType, adaptor.dest(), adaptor.source(), 679 adaptor.position()); 680 return success(); 681 } 682 }; 683 684 class VectorInsertOpConversion 685 : public ConvertOpToLLVMPattern<vector::InsertOp> { 686 public: 687 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 688 689 LogicalResult 690 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 691 ConversionPatternRewriter &rewriter) const override { 692 auto loc = insertOp->getLoc(); 693 auto sourceType = insertOp.getSourceType(); 694 auto destVectorType = insertOp.getDestVectorType(); 695 auto llvmResultType = typeConverter->convertType(destVectorType); 696 auto positionArrayAttr = insertOp.position(); 697 698 // Bail if result type cannot be lowered. 699 if (!llvmResultType) 700 return failure(); 701 702 // Overwrite entire vector with value. Should be handled by folder, but 703 // just to be safe. 704 if (positionArrayAttr.empty()) { 705 rewriter.replaceOp(insertOp, adaptor.source()); 706 return success(); 707 } 708 709 // One-shot insertion of a vector into an array (only requires insertvalue). 710 if (sourceType.isa<VectorType>()) { 711 Value inserted = rewriter.create<LLVM::InsertValueOp>( 712 loc, llvmResultType, adaptor.dest(), adaptor.source(), 713 positionArrayAttr); 714 rewriter.replaceOp(insertOp, inserted); 715 return success(); 716 } 717 718 // Potential extraction of 1-D vector from array. 
719 auto *context = insertOp->getContext(); 720 Value extracted = adaptor.dest(); 721 auto positionAttrs = positionArrayAttr.getValue(); 722 auto position = positionAttrs.back().cast<IntegerAttr>(); 723 auto oneDVectorType = destVectorType; 724 if (positionAttrs.size() > 1) { 725 oneDVectorType = reducedVectorTypeBack(destVectorType); 726 auto nMinusOnePositionAttrs = 727 ArrayAttr::get(context, positionAttrs.drop_back()); 728 extracted = rewriter.create<LLVM::ExtractValueOp>( 729 loc, typeConverter->convertType(oneDVectorType), extracted, 730 nMinusOnePositionAttrs); 731 } 732 733 // Insertion of an element into a 1-D LLVM vector. 734 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 735 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 736 Value inserted = rewriter.create<LLVM::InsertElementOp>( 737 loc, typeConverter->convertType(oneDVectorType), extracted, 738 adaptor.source(), constant); 739 740 // Potential insertion of resulting 1-D vector into array. 741 if (positionAttrs.size() > 1) { 742 auto nMinusOnePositionAttrs = 743 ArrayAttr::get(context, positionAttrs.drop_back()); 744 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 745 adaptor.dest(), inserted, 746 nMinusOnePositionAttrs); 747 } 748 749 rewriter.replaceOp(insertOp, inserted); 750 return success(); 751 } 752 }; 753 754 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 
755 /// 756 /// Example: 757 /// ``` 758 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 759 /// ``` 760 /// is rewritten into: 761 /// ``` 762 /// %r = splat %f0: vector<2x4xf32> 763 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 764 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 765 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 766 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 767 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 768 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 769 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 770 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 771 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 772 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 773 /// // %r3 holds the final value. 774 /// ``` 775 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 776 public: 777 using OpRewritePattern<FMAOp>::OpRewritePattern; 778 779 void initialize() { 780 // This pattern recursively unpacks one dimension at a time. The recursion 781 // bounded as the rank is strictly decreasing. 
782 setHasBoundedRewriteRecursion(); 783 } 784 785 LogicalResult matchAndRewrite(FMAOp op, 786 PatternRewriter &rewriter) const override { 787 auto vType = op.getVectorType(); 788 if (vType.getRank() < 2) 789 return failure(); 790 791 auto loc = op.getLoc(); 792 auto elemType = vType.getElementType(); 793 Value zero = rewriter.create<arith::ConstantOp>( 794 loc, elemType, rewriter.getZeroAttr(elemType)); 795 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 796 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 797 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 798 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 799 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 800 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 801 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 802 } 803 rewriter.replaceOp(op, desc); 804 return success(); 805 } 806 }; 807 808 /// Returns the strides if the memory underlying `memRefType` has a contiguous 809 /// static layout. 810 static llvm::Optional<SmallVector<int64_t, 4>> 811 computeContiguousStrides(MemRefType memRefType) { 812 int64_t offset; 813 SmallVector<int64_t, 4> strides; 814 if (failed(getStridesAndOffset(memRefType, strides, offset))) 815 return None; 816 if (!strides.empty() && strides.back() != 1) 817 return None; 818 // If no layout or identity layout, this is contiguous by definition. 819 if (memRefType.getLayout().isIdentity()) 820 return strides; 821 822 // Otherwise, we must determine contiguity form shapes. This can only ever 823 // work in static cases because MemRefType is underspecified to represent 824 // contiguous dynamic shapes in other ways than with just empty/identity 825 // layout. 
826 auto sizes = memRefType.getShape(); 827 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 828 if (ShapedType::isDynamic(sizes[index + 1]) || 829 ShapedType::isDynamicStrideOrOffset(strides[index]) || 830 ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 831 return None; 832 if (strides[index] != strides[index + 1] * sizes[index + 1]) 833 return None; 834 } 835 return strides; 836 } 837 838 class VectorTypeCastOpConversion 839 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 840 public: 841 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 842 843 LogicalResult 844 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 845 ConversionPatternRewriter &rewriter) const override { 846 auto loc = castOp->getLoc(); 847 MemRefType sourceMemRefType = 848 castOp.getOperand().getType().cast<MemRefType>(); 849 MemRefType targetMemRefType = castOp.getType(); 850 851 // Only static shape casts supported atm. 852 if (!sourceMemRefType.hasStaticShape() || 853 !targetMemRefType.hasStaticShape()) 854 return failure(); 855 856 auto llvmSourceDescriptorTy = 857 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 858 if (!llvmSourceDescriptorTy) 859 return failure(); 860 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 861 862 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 863 .dyn_cast_or_null<LLVM::LLVMStructType>(); 864 if (!llvmTargetDescriptorTy) 865 return failure(); 866 867 // Only contiguous source buffers supported atm. 868 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 869 if (!sourceStrides) 870 return failure(); 871 auto targetStrides = computeContiguousStrides(targetMemRefType); 872 if (!targetStrides) 873 return failure(); 874 // Only support static strides for now, regardless of contiguity. 
    // Bail out on any dynamic target stride: the descriptor is populated
    // with compile-time constants below.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr: reuse the source buffer, bitcast to the target
    // element pointer type.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr, likewise bitcast from the source descriptor.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    // Each dimension gets its static size and precomputed contiguous stride
    // emitted as an i64 constant.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    // The printed type must be convertible to LLVM at all.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    // Scalars dispatch on their own type, vectors on their element type.
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      // Index is printed through the unsigned 64-bit entry point.
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // No runtime support for integers wider than 64 bits.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // No runtime support for integers wider than 64 bits.
          return failure();
        }
      }
    } else {
      // Element type has no matching runtime print function.
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    // Terminate the whole value with a newline.
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Extension applied to integer operands before the 64-bit print call.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  // Recursively emits print calls for `value`: at rank 0 the (possibly
  // extended) scalar is printed directly; otherwise the leading dimension is
  // unrolled between open/close brackets with comma separators.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      // Peel one dimension: rank-1 recursion prints scalars.
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ?
          reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      // Comma between elements, none after the last one.
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call to a runtime print function (no results).
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  // Reduction conversion carries the FP-reassociation flag through to the
  // generated LLVM reduction intrinsics.
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
          converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

/// Populate the given list with patterns that convert the Vector matrix
/// multiply/transpose intrinsic ops to LLVM.
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}