//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
// E.g. vector<2x3x4xf32> -> vector<3x4xf32>. Requires rank > 1.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
// E.g. vector<2x3x4xf32> -> vector<4xf32>. Requires rank > 1.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
39 static Value insertOne(ConversionPatternRewriter &rewriter, 40 LLVMTypeConverter &typeConverter, Location loc, 41 Value val1, Value val2, Type llvmType, int64_t rank, 42 int64_t pos) { 43 assert(rank > 0 && "0-D vector corner case should have been handled already"); 44 if (rank == 1) { 45 auto idxType = rewriter.getIndexType(); 46 auto constant = rewriter.create<LLVM::ConstantOp>( 47 loc, typeConverter.convertType(idxType), 48 rewriter.getIntegerAttr(idxType, pos)); 49 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 50 constant); 51 } 52 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 53 rewriter.getI64ArrayAttr(pos)); 54 } 55 56 // Helper that picks the proper sequence for extracting. 57 static Value extractOne(ConversionPatternRewriter &rewriter, 58 LLVMTypeConverter &typeConverter, Location loc, 59 Value val, Type llvmType, int64_t rank, int64_t pos) { 60 if (rank <= 1) { 61 auto idxType = rewriter.getIndexType(); 62 auto constant = rewriter.create<LLVM::ConstantOp>( 63 loc, typeConverter.convertType(idxType), 64 rewriter.getIntegerAttr(idxType, pos)); 65 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 66 constant); 67 } 68 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 69 rewriter.getI64ArrayAttr(pos)); 70 } 71 72 // Helper that returns data layout alignment of a memref. 73 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 74 MemRefType memrefType, unsigned &align) { 75 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 76 if (!elementTy) 77 return failure(); 78 79 // TODO: this should use the MLIR data layout when it becomes available and 80 // stop depending on translation. 81 llvm::LLVMContext llvmContext; 82 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 83 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 84 return success(); 85 } 86 87 // Add an index vector component to a base pointer. 
// This almost always succeeds unless the last stride is non-unit or the
// memory space is not zero (only unit-stride, default-memory-space memrefs
// are supported).
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  // GEP the base pointer with the index vector, yielding a vector of pointers
  // sized by the leading dimension of `vType`.
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM; higher ranks must be
    // unrolled first.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
/// Replaces a vector.load with an llvm.load of the computed vector pointer.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

/// Replaces a vector.maskedload with the llvm.intr.masked.load counterpart.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

/// Replaces a vector.store with an llvm.store through the vector pointer.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}

/// Replaces a vector.maskedstore with the llvm.intr.masked.store counterpart.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore. Resolves alignment and the strided element address,
/// then dispatches to the matching replaceLoadOrStoreOp overload above.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: strided element pointer, bitcast to a vector pointer.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: build a vector of pointers from the base pointer and
    // the index vector.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: build a vector of pointers from the base pointer and
    // the index vector.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.valueToStore(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  // `reassociateFPRed` is forwarded to the fadd/fmul reduction intrinsics to
  // allow reassociation of floating-point reductions.
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getOperands()[0];
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      // Dispatch on the reduction kind string to the matching
      // llvm.intr.vector.reduce.* intrinsic; unknown kinds fail the match.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "minui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operand);
      else if (kind == "minsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operand);
      else if (kind == "maxui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operand);
      else if (kind == "maxsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operand);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
                                                            llvmType, operand);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
                                                             llvmType, operand);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero when not provided).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one when not provided).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "minf")
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    else if (kind == "maxf")
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    else
      return failure();
    return success();
  }

private:
  // Whether fadd/fmul reductions may be reassociated (faster, but changes
  // floating-point rounding behavior).
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.shuffle.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract/insert the individual values one by one,
    // starting from an undef result.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      // Mask positions >= v1Dim select from the second operand.
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Conversion pattern for a vector.extractelement.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // 0-D vectors have no position operand; extract at constant index 0.
    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.vector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.extract.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.vector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array: peel off all but the
    // last position with a single extractvalue.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///   vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///   llvm.intr.fmuladd %va, %va, %va:
///     (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///     -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Ranks >= 2 are handled by the rank-reducing VectorFMAOpNDRewritePattern.
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // 0-D vectors have no position operand; insert at constant index 0.
    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.insert.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.source());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array: peel off all but the
    // last position to reach the 1-D vector we insert into.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector back into the array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0: vector<2x4xf32>
///   %va = vector.extractvalue %a[0] : vector<2x4xf32>
///   %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///   %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///   %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///   %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    // Only n-D (n >= 2) FMAs; 1-D is handled by VectorFMAOp1DConversion.
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    // Unroll the leading dimension: extract (n-1)-D slices, fma them, and
    // insert each result into the splat-initialized accumulator.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout; None otherwise.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  // A non-unit innermost stride can never be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
800 auto sizes = memRefType.getShape(); 801 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 802 if (ShapedType::isDynamic(sizes[index + 1]) || 803 ShapedType::isDynamicStrideOrOffset(strides[index]) || 804 ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 805 return None; 806 if (strides[index] != strides[index + 1] * sizes[index + 1]) 807 return None; 808 } 809 return strides; 810 } 811 812 class VectorTypeCastOpConversion 813 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 814 public: 815 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 816 817 LogicalResult 818 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 819 ConversionPatternRewriter &rewriter) const override { 820 auto loc = castOp->getLoc(); 821 MemRefType sourceMemRefType = 822 castOp.getOperand().getType().cast<MemRefType>(); 823 MemRefType targetMemRefType = castOp.getType(); 824 825 // Only static shape casts supported atm. 826 if (!sourceMemRefType.hasStaticShape() || 827 !targetMemRefType.hasStaticShape()) 828 return failure(); 829 830 auto llvmSourceDescriptorTy = 831 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 832 if (!llvmSourceDescriptorTy) 833 return failure(); 834 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 835 836 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 837 .dyn_cast_or_null<LLVM::LLVMStructType>(); 838 if (!llvmTargetDescriptorTy) 839 return failure(); 840 841 // Only contiguous source buffers supported atm. 842 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 843 if (!sourceStrides) 844 return failure(); 845 auto targetStrides = computeContiguousStrides(targetMemRefType); 846 if (!targetStrides) 847 return failure(); 848 // Only support static strides for now, regardless of contiguity. 
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr. The source pointer is reused, only bitcast to the
    // target element pointer type.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr, likewise bitcast from the source descriptor.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0 (the attribute is index-typed, the constant is i64 to
    // match the descriptor field).
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    // Both sizes and strides are static here (checked above), so each
    // descriptor field is filled with an i64 constant.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Lowers vector.print to calls into a small runtime support library.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    // Bail out early if the printed type has no LLVM lowering at all.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    // Select the runtime print function matching the element type; anything
    // not covered below (f16, >64-bit ints, ...) fails the match.
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      // Index is printed through the unsigned 64-bit entry point.
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        // Signless integers are printed through the signed entry point.
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ?
        vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
              conversion);
    // Terminate the whole print with a newline call.
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Extension applied to each scalar before the runtime print call.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  /// Recursively emits the print calls for `value`: scalars are (optionally
  /// extended and) passed to `printer`; vectors are printed as a
  /// bracket-enclosed, comma-separated list of their elements, descending one
  /// rank per recursion step.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    // Scalar leaf case: extend to 64 bits if requested, then print.
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    // Vector case: open bracket, print elements comma-separated, close.
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    // Innermost (rank 0/1) vectors: extract scalar elements one by one.
    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ?
          1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        // Comma between elements, not after the last one.
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    // Higher-rank vectors: peel off the leading dimension and recurse on each
    // rank-(n-1) sub-vector.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call to `ref` (a runtime print function) with the given
  // operands; all these runtime entry points return void.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
1055 void mlir::populateVectorToLLVMConversionPatterns( 1056 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1057 bool reassociateFPReductions) { 1058 MLIRContext *ctx = converter.getDialect()->getContext(); 1059 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1060 populateVectorInsertExtractStridedSliceTransforms(patterns); 1061 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1062 patterns 1063 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1064 VectorExtractElementOpConversion, VectorExtractOpConversion, 1065 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1066 VectorInsertOpConversion, VectorPrintOpConversion, 1067 VectorTypeCastOpConversion, 1068 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1069 VectorLoadStoreConversion<vector::MaskedLoadOp, 1070 vector::MaskedLoadOpAdaptor>, 1071 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1072 VectorLoadStoreConversion<vector::MaskedStoreOp, 1073 vector::MaskedStoreOpAdaptor>, 1074 VectorGatherOpConversion, VectorScatterOpConversion, 1075 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( 1076 converter); 1077 // Transfer ops with rank > 1 are handled by VectorToSCF. 1078 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1079 } 1080 1081 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1082 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1083 patterns.add<VectorMatmulOpConversion>(converter); 1084 patterns.add<VectorFlatTransposeOpConversion>(converter); 1085 } 1086