//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}
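// Note: the insertOne/extractOne helpers come in two mirrored flavors. The
// ConversionPatternRewriter overloads build LLVM dialect ops during dialect
// conversion, while the PatternRewriter overloads build vector dialect ops
// for the progressive rewrites that run before conversion.
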
// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
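// For example, for an array attribute holding [0, 1, 2, 3], dropFront = 1 and
// dropBack = 1 yield the vector {1, 2}.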
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns the data-layout-preferred alignment for the element type
// of the memref accessed by `op`.
template <typename T>
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                        unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
  align = dataLayout.getPrefTypeAlignment(
      LLVM::convertLLVMType(elementTy.cast<LLVM::LLVMType>()));
  return success();
}

// Helper that returns a vector of pointers given a base and an index vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    LLVMTypeConverter &typeConverter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();

  // Create a vector of pointers from base and indices.
  MemRefDescriptor memRefDescriptor(memref);
  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
  int64_t size = vType.getDimSize(0);
  auto pType = memRefDescriptor.getElementType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
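/// Example (illustrative; the attribute values must match the operand shapes):
/// ```
///   %C = vector.matrix_multiply %A, %B
///     {lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32}
///     : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```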
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
                              adaptor.indices(), gather.getMemRefType(), vType,
                              iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0)
                       ? ValueRange({})
                       : adaptor.pass_thru();
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
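/// The value vector is scattered into memory through a vector of pointers
/// obtained from a single GEP of the base pointer with the index vector. As
/// with the gather lowering above, only flat, unit-stride memrefs in memory
/// space 0 are currently handled (see getIndexedPtrs).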
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
                              adaptor.indices(), scatter.getMemRefType(), vType,
                              iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for all vector reductions, lowering them to the
/// corresponding LLVM vector reduction intrinsics.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  const bool reassociateFPReductions;
};

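/// Conversion pattern for a vector.shuffle. A rank-1 shuffle of two operands
/// of the exact same type maps directly onto llvm.shufflevector; every other
/// case is unrolled into individual extract/insert pairs.
///
/// Example (illustrative):
/// ```
///   %0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xf32>, vector<2xf32>
/// ```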
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, extract each value and insert it one at a time.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

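/// Conversion pattern for a vector.extractelement; this is a trivial 1-1
/// mapping onto llvm.extractelement.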
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

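/// Conversion pattern for a vector.extract. Extracting a subvector from an
/// n-D vector maps onto llvm.extractvalue; extracting a scalar additionally
/// requires a final llvm.extractelement from the innermost 1-D vector.
///
/// Example (illustrative):
/// ```
///   %1 = vector.extract %0[0, 1] : vector<4x8xf32>
/// ```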
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement; this is a trivial 1-1
/// mapping onto llvm.insertelement.
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

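/// Conversion pattern for a vector.insert. Inserting a subvector into an
/// n-D vector maps onto llvm.insertvalue; inserting a scalar additionally
/// requires an llvm.insertelement on the innermost 1-D vector, wrapped in an
/// extractvalue/insertvalue pair when the destination rank exceeds 1.
///
/// Example (illustrative):
/// ```
///   %1 = vector.insert %s, %0[3] : f32 into vector<4xf32>
/// ```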
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of the 1-D destination subvector from the array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of the resulting 1-D vector back into the array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
705 /// 706 /// Example: 707 /// ``` 708 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 709 /// ``` 710 /// is rewritten into: 711 /// ``` 712 /// %r = splat %f0: vector<2x4xf32> 713 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 714 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 715 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 716 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 717 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 718 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 719 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 720 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 721 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 722 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 723 /// // %r3 holds the final value. 724 /// ``` 725 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 726 public: 727 using OpRewritePattern<FMAOp>::OpRewritePattern; 728 729 LogicalResult matchAndRewrite(FMAOp op, 730 PatternRewriter &rewriter) const override { 731 auto vType = op.getVectorType(); 732 if (vType.getRank() < 2) 733 return failure(); 734 735 auto loc = op.getLoc(); 736 auto elemType = vType.getElementType(); 737 Value zero = rewriter.create<ConstantOp>(loc, elemType, 738 rewriter.getZeroAttr(elemType)); 739 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 740 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 741 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 742 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 743 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 744 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 745 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 746 } 747 rewriter.replaceOp(op, desc); 748 return success(); 749 } 750 }; 751 752 // When ranks are different, InsertStridedSlice needs to extract a properly 753 // ranked vector from the destination vector into which to insert. This pattern 754 // only takes care of this part and forwards the rest of the conversion to 755 // another pattern that converts InsertStridedSlice for operands of the same 756 // rank. 757 // 758 // RewritePattern for InsertStridedSliceOp where source and destination vectors 759 // have different ranks. In this case: 760 // 1. the proper subvector is extracted from the destination vector 761 // 2. a new InsertStridedSlice op is created to insert the source in the 762 // destination subvector 763 // 3. the destination subvector is inserted back in the proper place 764 // 4. the op is replaced by the result of step 3. 765 // The new InsertStridedSlice from step 2. will be picked up by a 766 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. The op is unrolled along its most major dimension:
//   1. each slice is extracted from the source vector
//   2. if the slice is itself a vector, a rank-reduced InsertStridedSlice op
//      inserts it into the matching subvector extracted from the destination
//   3. the result is inserted back into the accumulated destination vector.
// The rank-reduced op created in step 2. is matched again by this same
// pattern; the recursion is bounded because the rank strictly decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. Extract the proper subvector (or element) from source.
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

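/// Conversion pattern for a vector.type_cast between a memref of scalars and
/// a memref of the corresponding vector type, e.g. (illustrative):
/// ```
///   vector.type_cast %A : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// Only statically shaped, contiguous source memrefs are supported; the
/// lowering rebuilds the memref descriptor (bitcast pointers, zero offset,
/// new sizes and strides).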
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts are supported for now.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Only contiguous source memrefs are supported for now. Check the strides
    // only after getStridesAndOffset succeeds, since `strides` is left empty
    // on failure.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(sourceMemRefType, strides, offset)))
      return failure();
    bool isContiguous = (strides.back() == 1);
    if (isContiguous) {
      auto sizes = sourceMemRefType.getShape();
      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
          isContiguous = false;
          break;
        }
      }
    }
    if (!isContiguous)
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
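///
/// Example of an op handled by this pattern (illustrative):
/// ```
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```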
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer is always in address space 0, so an addrspacecast
    //    is needed when the source/dst memref is not in address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO: when the leaf transfer rank is k > 1 we need the last `k`
    //       dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let `dim` be the memref dimension; compute the vector comparison
    //    mask: [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};

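/// Conversion pattern for a vector.print of scalars and vectors, e.g.
/// (illustrative):
/// ```
///   vector.print %v : vector<4xf32>
/// ```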
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suited for very large vectors.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure the element type has runtime support; currently i1, i32, i64,
    // f32, and f64 are supported.
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() ==
          LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
        auto trueVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        auto falseVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (on first use) and lookup.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for printer method names.
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. extractelement + insertelement for the 1-D case
///   2. extract + optional extract_strided_slice + insert for the n-D case.
class VectorStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value extracted = extractOne(rewriter, loc, op.vector(), off);
      if (op.offsets().getValue().size() > 1) {
        extracted = rewriter.create<ExtractStridedSliceOp>(
            loc, extracted, getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.sizes(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, {res});
    return success();
  }

  /// This pattern creates recursive ExtractStridedSliceOp, but the recursion
  /// is bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
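/// These patterns are typically combined with the Standard-to-LLVM patterns
/// (as done in the pass below) so that memref and scalar operands lower
/// consistently alongside the vector operations.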
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

namespace {
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
  }
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         reassociateFPReductions);
  populateStdToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(getContext());
  if (failed(applyPartialConversion(getOperation(), target, patterns)))
    signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}