//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns a vector comparison that constructs a mask:
//   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
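//
// For example, with n = 4, an offset o, and a bound b, the helper below
// builds (schematically) the IR computing
//   mask = [0, 1, 2, 3] + [o, o, o, o] < [b, b, b, b]
// so that lane i of the mask is set if and only if i + o < b.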
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns the data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a vector of pointers, given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer is
// always in address space 0, so an addrspacecast is used when the source/dst
// memrefs are not in address space 0.
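// For example, for an f32 memref and an 8-lane vector type, this turns a
// pointer-to-f32 into a pointer-to-vector<8xf32> via a single bitcast (or an
// addrspacecast when the memref lives in a non-default address space).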
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt);
  if (memRefType.getMemorySpace() == 0)
    return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
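///
/// For example (schematically; a 4x16 by 16x3 multiply of flattened,
/// column-major matrices):
/// ```
///   %c = vector.matrix_multiply %a, %b
///          { lhs_rows = 4: i32, lhs_columns = 16: i32, rhs_columns = 3: i32 }
///        : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```
/// becomes a single llvm.intr.matrix.multiply carrying the same shape
/// attributes.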
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
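/// The op maps 1-1 to llvm.intr.masked.load; the only work done here is
/// resolving the base address into a vector pointer and deriving the
/// alignment attribute from the data layout of the memref element type (for
/// example, typically 4 bytes for an f32 memref).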
class VectorMaskedLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = load->getLoc();
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
    MemRefType memRefType = load.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    auto vtype = typeConverter->convertType(load.getResultVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = store->getLoc();
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
    MemRefType memRefType = store.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    auto vtype = typeConverter->convertType(store.getValueVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
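/// The gather is lowered by first turning the index vector into a vector of
/// pointers with a single GEP off the memref base (flat, unit-stride memrefs
/// only for now), and then emitting an llvm.intr.masked.gather on those
/// pointers.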
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getResultVectorType());
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
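/// For example (assuming the textual form of the op), an integer reduction
/// ```
///   %0 = vector.reduction "add", %v : vector<16xi32> into i32
/// ```
/// maps to llvm.intr.vector.reduce.add, while its floating-point counterpart
/// maps to llvm.intr.vector.reduce.fadd with an explicit accumulator and a
/// flag controlling reassociation.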
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
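/// A 1-D create_mask with bound %b is lowered through buildVectorComparison
/// above, i.e. for a vector<4xi1> result it becomes, schematically,
///   mask = [0, 1, 2, 3] < [%b, %b, %b, %b]
/// using 32-bit lane indices when index optimizations are enabled.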
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract and insert the values one by one: mask
    // positions below v1Dim select from v1, the rest select from v2.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};
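
/// Conversion pattern for vector.extract. An n-D position is lowered as an
/// llvm.extractvalue through the aggregate (array) levels, followed, when the
/// result is a scalar, by an llvm.extractelement on the remaining 1-D vector.
/// For example, extracting element [3][2] of a vector<4x8xf32> becomes an
/// extractvalue at position 3 followed by an extractelement at index 2.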
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///   vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///   llvm.intr.fmuladd %va, %va, %va:
///     (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///     -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0 : vector<2x4xf32>
///   %va = vector.extract %a[0] : vector<2x4xf32>
///   %vb = vector.extract %b[0] : vector<2x4xf32>
///   %vc = vector.extract %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extract %a[1] : vector<2x4xf32>
///   %vb2 = vector.extract %b[1] : vector<2x4xf32>
///   %vc2 = vector.extract %c[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each slice of the source vector along the most major
// dimension:
//   1. the corresponding subvector (or element) is extracted from the source
//      and the destination
//   2. a new, lower-rank InsertStridedSlice op is created to insert the source
//      subvector into the destination subvector
//   3. the result is inserted back in the proper place of the destination
// The lower-rank InsertStridedSlice ops created in step 2. are picked up
// recursively by this same pattern until the element level is reached.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. Extract the proper subvector (or element) from source.
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
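/// For example, a memref<?x4xf32> with no (or identity) layout is contiguous
/// by definition, with strides [4, 1]. For an explicit strided layout, each
/// stride must statically equal the product of the next stride and the next
/// size; otherwise contiguity cannot be proven and None is returned.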
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shapes. This can only
  // ever work in static cases because MemRefType is underspecified to
  // represent contiguous dynamic shapes in other ways than with just
  // empty/identity layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
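///
/// For example, a vector.transfer_read of vector<8xf32> at offset %i from a
/// memref of size %dim becomes, schematically, a masked load through a
/// bitcast vector pointer, with mask
///   [0 .. 7] + [%i .. %i] < [%dim .. %dim]
/// and a splat of the padding value as the masked-off pass-through (the mask
/// is skipped entirely when the dimension is known unmasked/in-bounds).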
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Only contiguous source tensors supported atm.
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector element type is a suffix of
      // `vectorType`.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    VectorType vtp = xferOp.getVectorType();
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    Value vectorDataPtr =
        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // 4. Let dim be the memref dimension, compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
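  // For example, printing a vector<2x2xf32> holding [[1, 2], [3, 4]] unrolls
  // into printOpen/print/printComma/printClose calls, followed by a final
  // printNewline, producing output along the lines of:
  //   ( ( 1, 2 ), ( 3, 4 ) )
  //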
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer = getPrintFloat(printOp);
    } else if (eltType.isF64()) {
      printer = getPrintDouble(printOp);
    } else if (eltType.isIndex()) {
      printer = getPrintU64(printOp);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = getPrintU64(printOp);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = getPrintI64(printOp);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<Type> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
                                    params));
  }

  // Helpers for method names.
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintU64(Operation *op) const {
    return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "printF32", Float32Type::get(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "printF64", Float64Type::get(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "printOpen", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "printClose", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "printComma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "printNewline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
/// 1. express single offset extract as a direct shuffle.
/// 2. extract + lower rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorMatmulOpConversion>(converter);
  patterns.insert<VectorFlatTransposeOpConversion>(converter);
}
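
// A minimal usage sketch (hypothetical pass body; the target and pass hooks
// below are assumptions for illustration, not part of this file): the two
// entry points above are meant to be combined and handed to the dialect
// conversion driver, e.g.:
//
//   LLVMTypeConverter converter(&getContext());
//   OwningRewritePatternList patterns;
//   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
//   populateVectorToLLVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(getContext());
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();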