//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}
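// For illustration, a minimal sketch of what these helpers emit (SSA names
// are made up): inserting into a 1-D vector uses llvm.insertelement with a
// constant index, while higher-rank values are arrays of vectors and are
// addressed with llvm.insertvalue:
//
//   // rank == 1, pos == 2:
//   %c2 = llvm.mlir.constant(2 : i64) : i64
//   %v1 = llvm.insertelement %elt, %v[%c2 : i64] : vector<8xf32>
//
//   // rank > 1, pos == 1 (array-of-vector representation):
//   %a1 = llvm.insertvalue %row, %a[1] : !llvm.array<4 x vector<8xf32>>
//
// The extract helpers mirror this with llvm.extractelement and
// llvm.extractvalue, respectively.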
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns the data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Return the minimal alignment value that satisfies all the AssumeAlignment
// uses of `value`. If no such uses exist, return 1.
static unsigned getAssumedAlignment(Value value) {
  unsigned align = 1;
  for (auto &u : value.getUses()) {
    Operation *owner = u.getOwner();
    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
      align = mlir::lcm(align, op.alignment());
  }
  return align;
}

// Helper that returns the data layout alignment of a memref associated with a
// transfer op, including additional information from assume_alignment calls
// on the source of the transfer.
static LogicalResult getTransferOpAlignment(LLVMTypeConverter &typeConverter,
                                            VectorTransferOpInterface xfer,
                                            unsigned &align) {
  if (failed(getMemRefAlignment(
          typeConverter, xfer.getShapedType().cast<MemRefType>(), align)))
    return failure();
  align = std::max(align, getAssumedAlignment(xfer.source()));
  return success();
}

// Helper that returns the data layout alignment of a memref associated with a
// load, store, scatter, or gather op, including additional information from
// assume_alignment calls on the base of the operation.
template <class OpAdaptor>
static LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
                                          OpAdaptor op, unsigned &align) {
  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
    return failure();
  align = std::max(align, getAssumedAlignment(op.base()));
  return success();
}
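// Worked example (values made up): if `%buf` has two users
// `memref.assume_alignment %buf, 4` and `memref.assume_alignment %buf, 8`,
// getAssumedAlignment returns lcm(lcm(1, 4), 8) = 8. A transfer op on `%buf`
// whose converted element type has a preferred data layout alignment of 4
// then gets max(4, 8) = 8 as its final alignment.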
// Add an index vector component to a base pointer. This fails if the last
// stride is non-unit or the memory space is not zero.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
  Value fill = rewriter.create<SplatOp>(loc, vecTy, adaptor.padding());

  unsigned align;
  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}
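// For illustration, a rough sketch of the two rewrites (SSA names and the
// alignment value are made up). An unmasked 1-D read such as
//
//   %v = vector.transfer_read %A[%i], %pad
//       : memref<?xf32>, vector<16xf32>
//
// becomes a plain llvm.load of the bitcast vector pointer,
//
//   %v = llvm.load %ptr {alignment = 4 : i64} : !llvm.ptr<vector<16xf32>>
//
// whereas the masked form becomes an llvm.intr.masked.load whose pass-through
// operand is a splat of the padding value.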
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
}

namespace {

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 operands[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};
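// For illustration, a rough example of the matmul lowering (dimensions chosen
// arbitrarily): a 2x3 * 3x4 product over row-major flattened 1-D vectors,
//
//   %c = vector.matrix_multiply %a, %b
//       {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
//       : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
//
// maps 1-1 onto llvm.intr.matrix.multiply with the same operands and
// attributes; vector.flat_transpose maps onto llvm.intr.matrix.transpose in
// the same way.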
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    auto adaptor = LoadOrStoreOpAdaptor(operands);
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
                                    align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};
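// For illustration, a rough sketch of the masked-load case (SSA names and
// alignment are made up):
//
//   %v = vector.maskedload %base[%i], %mask, %pass_thru
//       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
//
// becomes, after address resolution and the bitcast to a vector pointer,
// roughly
//
//   %v = llvm.intr.masked.load %ptr, %mask, %pass_thru {alignment = 4 : i32}
//       : (!llvm.ptr<vector<16xf32>>, vector<16xi1>, vector<16xf32>)
//       -> vector<16xf32>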
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
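// For illustration, a minimal sketch of the gather address computation (names
// made up): given the scalar base pointer %p resolved for %base[%i],
// getIndexedPtrs emits a single GEP that yields one pointer per lane, roughly
//
//   %ptrs = llvm.getelementptr %p[%index_vec]
//       : (!llvm.ptr<f32>, vector<16xi32>) -> !llvm.vec<16 x ptr<f32>>
//
// which then feeds llvm.intr.masked.gather (and llvm.intr.masked.scatter in
// the store direction).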
/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.valueToStore(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};
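// For illustration (made-up names): an integer reduction such as
//
//   %0 = vector.reduction "add", %v : vector<16xi32> into i32
//
// maps directly onto llvm.intr.vector.reduce.add, while a floating-point
// reduction with an optional fused accumulator,
//
//   %1 = vector.reduction "add", %v, %acc : vector<16xf32> into f32
//
// becomes llvm.intr.vector.reduce.fadd with %acc as the start value and the
// `reassociateFPReductions` flag controlling whether reassociation is
// allowed.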
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract/insert the individual values one by one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
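// For illustration (made-up names): with identical 1-D operand types,
//
//   %s = vector.shuffle %a, %b [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
//
// lowers to a single llvm.shufflevector with mask [0, 4, 1, 5]; any other
// combination of ranks/types is scalarized into the extract/insert loop
// above.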
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.vector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};
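// For illustration (made-up names): a scalar extract from a 2-D vector such
// as
//
//   %e = vector.extract %v[1, 3] : vector<2x8xf32>
//
// first peels the leading position with llvm.extractvalue [1] to obtain a
// vector<8xf32>, then extracts the scalar with llvm.extractelement at
// constant index 3.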
/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This pattern does not match vectors of rank >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.source());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
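// For illustration (made-up names): the insert conversion mirrors the extract
// conversion. A scalar insert such as
//
//   %r = vector.insert %f, %v [1, 3] : f32 into vector<2x8xf32>
//
// extracts the vector<8xf32> at position [1] with llvm.extractvalue, inserts
// %f with llvm.insertelement at constant index 3, and writes the updated 1-D
// vector back with llvm.insertvalue [1].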
/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///  %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};
// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This
// pattern only takes care of this part and forwards the rest of the
// conversion to another pattern that converts InsertStridedSlice for
// operands of the same rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination
// vectors have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};
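// For illustration (made-up names): inserting a vector<4xf32> into a
// vector<2x4xf32> at offsets [1, 0] is rewritten as
//
//   %sub = vector.extract %dst[1] : vector<2x4xf32>
//   %ins = vector.insert_strided_slice %src, %sub
//       {offsets = [0], strides = [1]} : vector<4xf32> into vector<4xf32>
//   %res = vector.insert %ins, %dst[1] : vector<4xf32> into vector<2x4xf32>
//
// where the inner same-rank insert_strided_slice is then handled by the
// pattern below.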
// RewritePattern for InsertStridedSliceOp where source and destination
// vectors have the same rank. In this case, we peel off the most-major
// dimension of the source:
//   1. each subvector (or element) of the source is extracted
//   2. if the extracted value is a vector, the matching subvector is
//      extracted from the destination and a new, lower-rank
//      InsertStridedSlice op inserts the source subvector into it
//   3. the result is inserted back into the destination at the proper offset
//   4. the op is replaced by the fully updated destination vector.
// The lower-rank InsertStridedSlice ops created in step 2. are picked up
// again by this same pattern; the recursion terminates at rank 1.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};
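// For illustration (made-up names): inserting %src : vector<2x4xf32> into
// %dst : vector<4x4xf32> at offsets [1, 0] with unit strides peels the
// leading dimension into two iterations; the first one is
//
//   %s0 = vector.extract %src[0] : vector<2x4xf32>
//   %d1 = vector.extract %dst[1] : vector<4x4xf32>
//   %i0 = vector.insert_strided_slice %s0, %d1
//       {offsets = [0], strides = [1]} : vector<4xf32> into vector<4xf32>
//   %r0 = vector.insert %i0, %dst[1] : vector<4xf32> into vector<4x4xf32>
//
// and the second iteration handles source slice 1 at destination offset 2.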
/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}
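// Worked example (numbers made up): a memref<2x3x4xf32> with strides
// [12, 4, 1] is contiguous because 12 == 4 * 3 and 4 == 1 * 4; with strides
// [16, 4, 1] the check 16 == 4 * 3 fails and None is returned. Any dynamic
// size or stride along the way also returns None.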
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
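// For illustration (made-up names): a cast such as
//
//   %vm = vector.type_cast %m : memref<8x8xf32> to memref<vector<8x8xf32>>
//
// keeps the underlying buffer and only rebuilds the memref descriptor: the
// allocated and aligned pointers are bitcast to the vector element type, the
// offset is set to 0, and the (static) sizes and strides of the target type
// are written into the descriptor.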
/// Conversion pattern that converts a 1-D vector transfer read/write op into
/// a masked or unmasked read/write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 || xferOp.indices().empty())
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Last dimension must be contiguous. (Otherwise: Use VectorToSCF.)
    if (!isLastMemrefDimUnitStride(memRefType))
      return failure();
    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that memref vector type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // Get the source/dst address as an LLVM vector pointer.
    VectorType vtp = xferOp.getVectorType();
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    Value vectorDataPtr =
        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

    // Rewrite as an unmasked read / write.
    if (!xferOp.mask())
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(),
                                       loc, xferOp, operands, vectorDataPtr,
                                       xferOp.mask());
  }
};
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }
};
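// For illustration (made-up names): printing a vector<2xf32> unrolls into the
// elementary runtime calls
//
//   llvm.call @printOpen() : () -> ()
//   llvm.call @printF32(%e0) : (f32) -> ()
//   llvm.call @printComma() : () -> ()
//   llvm.call @printF32(%e1) : (f32) -> ()
//   llvm.call @printClose() : () -> ()
//   llvm.call @printNewline() : () -> ()
//
// with the scalar elements extracted one by one beforehand.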
/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. express single offset extract as a direct shuffle.
///   2. extract + lower rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace
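// For illustration (made-up names): a single-offset slice such as
//
//   %s = vector.extract_strided_slice %v
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xf32> to vector<4xf32>
//
// is emitted as one shuffle, vector.shuffle %v, %v [2, 3, 4, 5]; slices with
// more than one offset recurse through the extract/insert loop instead.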
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern,
               VectorInsertStridedSliceOpDifferentRankRewritePattern,
               VectorInsertStridedSliceOpSameRankRewritePattern,
               VectorExtractStridedSliceOpConversion>(ctx);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorTransferConversion<TransferReadOp>,
           VectorTransferConversion<TransferWriteOp>>(converter);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}