1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
22 #include "mlir/Transforms/DialectConversion.h"
23
24 using namespace mlir;
25 using namespace mlir::vector;
26
27 // Helper to reduce vector type by one rank at front.
reducedVectorTypeFront(VectorType tp)28 static VectorType reducedVectorTypeFront(VectorType tp) {
29 assert((tp.getRank() > 1) && "unlowerable vector type");
30 unsigned numScalableDims = tp.getNumScalableDims();
31 if (tp.getShape().size() == numScalableDims)
32 --numScalableDims;
33 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
34 numScalableDims);
35 }
36
37 // Helper to reduce vector type by *all* but one rank at back.
reducedVectorTypeBack(VectorType tp)38 static VectorType reducedVectorTypeBack(VectorType tp) {
39 assert((tp.getRank() > 1) && "unlowerable vector type");
40 unsigned numScalableDims = tp.getNumScalableDims();
41 if (numScalableDims > 0)
42 --numScalableDims;
43 return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
44 numScalableDims);
45 }
46
47 // Helper that picks the proper sequence for inserting.
insertOne(ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,Location loc,Value val1,Value val2,Type llvmType,int64_t rank,int64_t pos)48 static Value insertOne(ConversionPatternRewriter &rewriter,
49 LLVMTypeConverter &typeConverter, Location loc,
50 Value val1, Value val2, Type llvmType, int64_t rank,
51 int64_t pos) {
52 assert(rank > 0 && "0-D vector corner case should have been handled already");
53 if (rank == 1) {
54 auto idxType = rewriter.getIndexType();
55 auto constant = rewriter.create<LLVM::ConstantOp>(
56 loc, typeConverter.convertType(idxType),
57 rewriter.getIntegerAttr(idxType, pos));
58 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
59 constant);
60 }
61 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
62 rewriter.getI64ArrayAttr(pos));
63 }
64
65 // Helper that picks the proper sequence for extracting.
extractOne(ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,Location loc,Value val,Type llvmType,int64_t rank,int64_t pos)66 static Value extractOne(ConversionPatternRewriter &rewriter,
67 LLVMTypeConverter &typeConverter, Location loc,
68 Value val, Type llvmType, int64_t rank, int64_t pos) {
69 if (rank <= 1) {
70 auto idxType = rewriter.getIndexType();
71 auto constant = rewriter.create<LLVM::ConstantOp>(
72 loc, typeConverter.convertType(idxType),
73 rewriter.getIntegerAttr(idxType, pos));
74 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
75 constant);
76 }
77 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
78 rewriter.getI64ArrayAttr(pos));
79 }
80
81 // Helper that returns data layout alignment of a memref.
getMemRefAlignment(LLVMTypeConverter & typeConverter,MemRefType memrefType,unsigned & align)82 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
83 MemRefType memrefType, unsigned &align) {
84 Type elementTy = typeConverter.convertType(memrefType.getElementType());
85 if (!elementTy)
86 return failure();
87
88 // TODO: this should use the MLIR data layout when it becomes available and
89 // stop depending on translation.
90 llvm::LLVMContext llvmContext;
91 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
92 .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
93 return success();
94 }
95
96 // Add an index vector component to a base pointer. This almost always succeeds
97 // unless the last stride is non-unit or the memory space is not zero.
getIndexedPtrs(ConversionPatternRewriter & rewriter,Location loc,Value memref,Value base,Value index,MemRefType memRefType,VectorType vType,Value & ptrs)98 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
99 Location loc, Value memref, Value base,
100 Value index, MemRefType memRefType,
101 VectorType vType, Value &ptrs) {
102 int64_t offset;
103 SmallVector<int64_t, 4> strides;
104 auto successStrides = getStridesAndOffset(memRefType, strides, offset);
105 if (failed(successStrides) || strides.back() != 1 ||
106 memRefType.getMemorySpaceAsInt() != 0)
107 return failure();
108 auto pType = MemRefDescriptor(memref).getElementPtrType();
109 auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
110 ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
111 return success();
112 }
113
114 // Casts a strided element pointer to a vector pointer. The vector pointer
115 // will be in the same address space as the incoming memref type.
castDataPtr(ConversionPatternRewriter & rewriter,Location loc,Value ptr,MemRefType memRefType,Type vt)116 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
117 Value ptr, MemRefType memRefType, Type vt) {
118 auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
119 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
120 }
121
122 namespace {
123
/// Trivial Vector to LLVM conversions.
///
/// Alias for the one-to-one conversion pattern instantiated with
/// vector::VectorScaleOp as the source op and LLVM::vscale as the target op.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
127
128 /// Conversion pattern for a vector.bitcast.
129 class VectorBitCastOpConversion
130 : public ConvertOpToLLVMPattern<vector::BitCastOp> {
131 public:
132 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
133
134 LogicalResult
matchAndRewrite(vector::BitCastOp bitCastOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const135 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const override {
137 // Only 0-D and 1-D vectors can be lowered to LLVM.
138 VectorType resultTy = bitCastOp.getResultVectorType();
139 if (resultTy.getRank() > 1)
140 return failure();
141 Type newResultTy = typeConverter->convertType(resultTy);
142 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
143 adaptor.getOperands()[0]);
144 return success();
145 }
146 };
147
148 /// Conversion pattern for a vector.matrix_multiply.
149 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
150 class VectorMatmulOpConversion
151 : public ConvertOpToLLVMPattern<vector::MatmulOp> {
152 public:
153 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
154
155 LogicalResult
matchAndRewrite(vector::MatmulOp matmulOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const156 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
159 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
160 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
161 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
162 return success();
163 }
164 };
165
166 /// Conversion pattern for a vector.flat_transpose.
167 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
168 class VectorFlatTransposeOpConversion
169 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
170 public:
171 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
172
173 LogicalResult
matchAndRewrite(vector::FlatTransposeOp transOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const174 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
175 ConversionPatternRewriter &rewriter) const override {
176 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
177 transOp, typeConverter->convertType(transOp.getRes().getType()),
178 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
179 return success();
180 }
181 };
182
183 /// Overloaded utility that replaces a vector.load, vector.store,
184 /// vector.maskedload and vector.maskedstore with their respective LLVM
185 /// couterparts.
replaceLoadOrStoreOp(vector::LoadOp loadOp,vector::LoadOpAdaptor adaptor,VectorType vectorTy,Value ptr,unsigned align,ConversionPatternRewriter & rewriter)186 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
187 vector::LoadOpAdaptor adaptor,
188 VectorType vectorTy, Value ptr, unsigned align,
189 ConversionPatternRewriter &rewriter) {
190 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
191 }
192
replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,vector::MaskedLoadOpAdaptor adaptor,VectorType vectorTy,Value ptr,unsigned align,ConversionPatternRewriter & rewriter)193 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
194 vector::MaskedLoadOpAdaptor adaptor,
195 VectorType vectorTy, Value ptr, unsigned align,
196 ConversionPatternRewriter &rewriter) {
197 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
198 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
199 }
200
replaceLoadOrStoreOp(vector::StoreOp storeOp,vector::StoreOpAdaptor adaptor,VectorType vectorTy,Value ptr,unsigned align,ConversionPatternRewriter & rewriter)201 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
202 vector::StoreOpAdaptor adaptor,
203 VectorType vectorTy, Value ptr, unsigned align,
204 ConversionPatternRewriter &rewriter) {
205 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
206 ptr, align);
207 }
208
replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,vector::MaskedStoreOpAdaptor adaptor,VectorType vectorTy,Value ptr,unsigned align,ConversionPatternRewriter & rewriter)209 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
210 vector::MaskedStoreOpAdaptor adaptor,
211 VectorType vectorTy, Value ptr, unsigned align,
212 ConversionPatternRewriter &rewriter) {
213 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
214 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
215 }
216
217 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
218 /// vector.maskedstore.
219 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
220 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
221 public:
222 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
223
224 LogicalResult
matchAndRewrite(LoadOrStoreOp loadOrStoreOp,typename LoadOrStoreOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const225 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
226 typename LoadOrStoreOp::Adaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override {
228 // Only 1-D vectors can be lowered to LLVM.
229 VectorType vectorTy = loadOrStoreOp.getVectorType();
230 if (vectorTy.getRank() > 1)
231 return failure();
232
233 auto loc = loadOrStoreOp->getLoc();
234 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
235
236 // Resolve alignment.
237 unsigned align;
238 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
239 return failure();
240
241 // Resolve address.
242 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
243 .template cast<VectorType>();
244 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
245 adaptor.getIndices(), rewriter);
246 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
247
248 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
249 return success();
250 }
251 };
252
253 /// Conversion pattern for a vector.gather.
254 class VectorGatherOpConversion
255 : public ConvertOpToLLVMPattern<vector::GatherOp> {
256 public:
257 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
258
259 LogicalResult
matchAndRewrite(vector::GatherOp gather,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const260 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter) const override {
262 auto loc = gather->getLoc();
263 MemRefType memRefType = gather.getMemRefType();
264
265 // Resolve alignment.
266 unsigned align;
267 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
268 return failure();
269
270 // Resolve address.
271 Value ptrs;
272 VectorType vType = gather.getVectorType();
273 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
274 adaptor.getIndices(), rewriter);
275 if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
276 adaptor.getIndexVec(), memRefType, vType, ptrs)))
277 return failure();
278
279 // Replace with the gather intrinsic.
280 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
281 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
282 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
283 return success();
284 }
285 };
286
287 /// Conversion pattern for a vector.scatter.
288 class VectorScatterOpConversion
289 : public ConvertOpToLLVMPattern<vector::ScatterOp> {
290 public:
291 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
292
293 LogicalResult
matchAndRewrite(vector::ScatterOp scatter,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const294 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override {
296 auto loc = scatter->getLoc();
297 MemRefType memRefType = scatter.getMemRefType();
298
299 // Resolve alignment.
300 unsigned align;
301 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
302 return failure();
303
304 // Resolve address.
305 Value ptrs;
306 VectorType vType = scatter.getVectorType();
307 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
308 adaptor.getIndices(), rewriter);
309 if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
310 adaptor.getIndexVec(), memRefType, vType, ptrs)))
311 return failure();
312
313 // Replace with the scatter intrinsic.
314 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
315 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
316 rewriter.getI32IntegerAttr(align));
317 return success();
318 }
319 };
320
321 /// Conversion pattern for a vector.expandload.
322 class VectorExpandLoadOpConversion
323 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
324 public:
325 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
326
327 LogicalResult
matchAndRewrite(vector::ExpandLoadOp expand,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const328 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter) const override {
330 auto loc = expand->getLoc();
331 MemRefType memRefType = expand.getMemRefType();
332
333 // Resolve address.
334 auto vtype = typeConverter->convertType(expand.getVectorType());
335 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
336 adaptor.getIndices(), rewriter);
337
338 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
339 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
340 return success();
341 }
342 };
343
344 /// Conversion pattern for a vector.compressstore.
345 class VectorCompressStoreOpConversion
346 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
347 public:
348 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
349
350 LogicalResult
matchAndRewrite(vector::CompressStoreOp compress,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const351 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
352 ConversionPatternRewriter &rewriter) const override {
353 auto loc = compress->getLoc();
354 MemRefType memRefType = compress.getMemRefType();
355
356 // Resolve address.
357 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
358 adaptor.getIndices(), rewriter);
359
360 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
361 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
362 return success();
363 }
364 };
365
366 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
367 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
368 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
369 /// non-null.
370 template <class VectorOp, class ScalarOp>
createIntegerReductionArithmeticOpLowering(ConversionPatternRewriter & rewriter,Location loc,Type llvmType,Value vectorOperand,Value accumulator)371 static Value createIntegerReductionArithmeticOpLowering(
372 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
373 Value vectorOperand, Value accumulator) {
374 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
375 if (accumulator)
376 result = rewriter.create<ScalarOp>(loc, accumulator, result);
377 return result;
378 }
379
380 /// Helper method to lower a `vector.reduction` operation that performs
381 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
382 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
383 /// the accumulator value if non-null.
384 template <class VectorOp>
createIntegerReductionComparisonOpLowering(ConversionPatternRewriter & rewriter,Location loc,Type llvmType,Value vectorOperand,Value accumulator,LLVM::ICmpPredicate predicate)385 static Value createIntegerReductionComparisonOpLowering(
386 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
387 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
388 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
389 if (accumulator) {
390 Value cmp =
391 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
392 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
393 }
394 return result;
395 }
396
397 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
398 /// with vector types.
createMinMaxF(OpBuilder & builder,Location loc,Value lhs,Value rhs,bool isMin)399 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
400 Value rhs, bool isMin) {
401 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
402 Type i1Type = builder.getI1Type();
403 if (auto vecType = lhs.getType().dyn_cast<VectorType>())
404 i1Type = VectorType::get(vecType.getShape(), i1Type);
405 Value cmp = builder.create<LLVM::FCmpOp>(
406 loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
407 lhs, rhs);
408 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
409 Value isNan = builder.create<LLVM::FCmpOp>(
410 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
411 Value nan = builder.create<LLVM::ConstantOp>(
412 loc, lhs.getType(),
413 builder.getFloatAttr(floatType,
414 APFloat::getQNaN(floatType.getFloatSemantics())));
415 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
416 }
417
418 /// Conversion pattern for all vector reductions.
419 class VectorReductionOpConversion
420 : public ConvertOpToLLVMPattern<vector::ReductionOp> {
421 public:
VectorReductionOpConversion(LLVMTypeConverter & typeConv,bool reassociateFPRed)422 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
423 bool reassociateFPRed)
424 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
425 reassociateFPReductions(reassociateFPRed) {}
426
427 LogicalResult
matchAndRewrite(vector::ReductionOp reductionOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const428 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter) const override {
430 auto kind = reductionOp.getKind();
431 Type eltType = reductionOp.getDest().getType();
432 Type llvmType = typeConverter->convertType(eltType);
433 Value operand = adaptor.getVector();
434 Value acc = adaptor.getAcc();
435 Location loc = reductionOp.getLoc();
436 if (eltType.isIntOrIndex()) {
437 // Integer reductions: add/mul/min/max/and/or/xor.
438 Value result;
439 switch (kind) {
440 case vector::CombiningKind::ADD:
441 result =
442 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
443 LLVM::AddOp>(
444 rewriter, loc, llvmType, operand, acc);
445 break;
446 case vector::CombiningKind::MUL:
447 result =
448 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
449 LLVM::MulOp>(
450 rewriter, loc, llvmType, operand, acc);
451 break;
452 case vector::CombiningKind::MINUI:
453 result = createIntegerReductionComparisonOpLowering<
454 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
455 LLVM::ICmpPredicate::ule);
456 break;
457 case vector::CombiningKind::MINSI:
458 result = createIntegerReductionComparisonOpLowering<
459 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
460 LLVM::ICmpPredicate::sle);
461 break;
462 case vector::CombiningKind::MAXUI:
463 result = createIntegerReductionComparisonOpLowering<
464 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
465 LLVM::ICmpPredicate::uge);
466 break;
467 case vector::CombiningKind::MAXSI:
468 result = createIntegerReductionComparisonOpLowering<
469 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
470 LLVM::ICmpPredicate::sge);
471 break;
472 case vector::CombiningKind::AND:
473 result =
474 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
475 LLVM::AndOp>(
476 rewriter, loc, llvmType, operand, acc);
477 break;
478 case vector::CombiningKind::OR:
479 result =
480 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
481 LLVM::OrOp>(
482 rewriter, loc, llvmType, operand, acc);
483 break;
484 case vector::CombiningKind::XOR:
485 result =
486 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
487 LLVM::XOrOp>(
488 rewriter, loc, llvmType, operand, acc);
489 break;
490 default:
491 return failure();
492 }
493 rewriter.replaceOp(reductionOp, result);
494
495 return success();
496 }
497
498 if (!eltType.isa<FloatType>())
499 return failure();
500
501 // Floating-point reductions: add/mul/min/max
502 if (kind == vector::CombiningKind::ADD) {
503 // Optional accumulator (or zero).
504 Value acc = adaptor.getOperands().size() > 1
505 ? adaptor.getOperands()[1]
506 : rewriter.create<LLVM::ConstantOp>(
507 reductionOp->getLoc(), llvmType,
508 rewriter.getZeroAttr(eltType));
509 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
510 reductionOp, llvmType, acc, operand,
511 rewriter.getBoolAttr(reassociateFPReductions));
512 } else if (kind == vector::CombiningKind::MUL) {
513 // Optional accumulator (or one).
514 Value acc = adaptor.getOperands().size() > 1
515 ? adaptor.getOperands()[1]
516 : rewriter.create<LLVM::ConstantOp>(
517 reductionOp->getLoc(), llvmType,
518 rewriter.getFloatAttr(eltType, 1.0));
519 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
520 reductionOp, llvmType, acc, operand,
521 rewriter.getBoolAttr(reassociateFPReductions));
522 } else if (kind == vector::CombiningKind::MINF) {
523 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
524 // NaNs/-0.0/+0.0 in the same way.
525 Value result =
526 rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
527 if (acc)
528 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
529 rewriter.replaceOp(reductionOp, result);
530 } else if (kind == vector::CombiningKind::MAXF) {
531 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
532 // NaNs/-0.0/+0.0 in the same way.
533 Value result =
534 rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
535 if (acc)
536 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
537 rewriter.replaceOp(reductionOp, result);
538 } else
539 return failure();
540
541 return success();
542 }
543
544 private:
545 const bool reassociateFPReductions;
546 };
547
548 class VectorShuffleOpConversion
549 : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
550 public:
551 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
552
553 LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const554 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
555 ConversionPatternRewriter &rewriter) const override {
556 auto loc = shuffleOp->getLoc();
557 auto v1Type = shuffleOp.getV1VectorType();
558 auto v2Type = shuffleOp.getV2VectorType();
559 auto vectorType = shuffleOp.getVectorType();
560 Type llvmType = typeConverter->convertType(vectorType);
561 auto maskArrayAttr = shuffleOp.getMask();
562
563 // Bail if result type cannot be lowered.
564 if (!llvmType)
565 return failure();
566
567 // Get rank and dimension sizes.
568 int64_t rank = vectorType.getRank();
569 assert(v1Type.getRank() == rank);
570 assert(v2Type.getRank() == rank);
571 int64_t v1Dim = v1Type.getDimSize(0);
572
573 // For rank 1, where both operands have *exactly* the same vector type,
574 // there is direct shuffle support in LLVM. Use it!
575 if (rank == 1 && v1Type == v2Type) {
576 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
577 loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
578 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
579 return success();
580 }
581
582 // For all other cases, insert the individual values individually.
583 Type eltType;
584 if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
585 eltType = arrayType.getElementType();
586 else
587 eltType = llvmType.cast<VectorType>().getElementType();
588 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
589 int64_t insPos = 0;
590 for (const auto &en : llvm::enumerate(maskArrayAttr)) {
591 int64_t extPos = en.value().cast<IntegerAttr>().getInt();
592 Value value = adaptor.getV1();
593 if (extPos >= v1Dim) {
594 extPos -= v1Dim;
595 value = adaptor.getV2();
596 }
597 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
598 eltType, rank, extPos);
599 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
600 llvmType, rank, insPos++);
601 }
602 rewriter.replaceOp(shuffleOp, insert);
603 return success();
604 }
605 };
606
607 class VectorExtractElementOpConversion
608 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
609 public:
610 using ConvertOpToLLVMPattern<
611 vector::ExtractElementOp>::ConvertOpToLLVMPattern;
612
613 LogicalResult
matchAndRewrite(vector::ExtractElementOp extractEltOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const614 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter) const override {
616 auto vectorType = extractEltOp.getVectorType();
617 auto llvmType = typeConverter->convertType(vectorType.getElementType());
618
619 // Bail if result type cannot be lowered.
620 if (!llvmType)
621 return failure();
622
623 if (vectorType.getRank() == 0) {
624 Location loc = extractEltOp.getLoc();
625 auto idxType = rewriter.getIndexType();
626 auto zero = rewriter.create<LLVM::ConstantOp>(
627 loc, typeConverter->convertType(idxType),
628 rewriter.getIntegerAttr(idxType, 0));
629 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
630 extractEltOp, llvmType, adaptor.getVector(), zero);
631 return success();
632 }
633
634 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
635 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
636 return success();
637 }
638 };
639
640 class VectorExtractOpConversion
641 : public ConvertOpToLLVMPattern<vector::ExtractOp> {
642 public:
643 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
644
645 LogicalResult
matchAndRewrite(vector::ExtractOp extractOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const646 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
647 ConversionPatternRewriter &rewriter) const override {
648 auto loc = extractOp->getLoc();
649 auto vectorType = extractOp.getVectorType();
650 auto resultType = extractOp.getResult().getType();
651 auto llvmResultType = typeConverter->convertType(resultType);
652 auto positionArrayAttr = extractOp.getPosition();
653
654 // Bail if result type cannot be lowered.
655 if (!llvmResultType)
656 return failure();
657
658 // Extract entire vector. Should be handled by folder, but just to be safe.
659 if (positionArrayAttr.empty()) {
660 rewriter.replaceOp(extractOp, adaptor.getVector());
661 return success();
662 }
663
664 // One-shot extraction of vector from array (only requires extractvalue).
665 if (resultType.isa<VectorType>()) {
666 Value extracted = rewriter.create<LLVM::ExtractValueOp>(
667 loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
668 rewriter.replaceOp(extractOp, extracted);
669 return success();
670 }
671
672 // Potential extraction of 1-D vector from array.
673 auto *context = extractOp->getContext();
674 Value extracted = adaptor.getVector();
675 auto positionAttrs = positionArrayAttr.getValue();
676 if (positionAttrs.size() > 1) {
677 auto oneDVectorType = reducedVectorTypeBack(vectorType);
678 auto nMinusOnePositionAttrs =
679 ArrayAttr::get(context, positionAttrs.drop_back());
680 extracted = rewriter.create<LLVM::ExtractValueOp>(
681 loc, typeConverter->convertType(oneDVectorType), extracted,
682 nMinusOnePositionAttrs);
683 }
684
685 // Remaining extraction of element from 1-D LLVM vector
686 auto position = positionAttrs.back().cast<IntegerAttr>();
687 auto i64Type = IntegerType::get(rewriter.getContext(), 64);
688 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
689 extracted =
690 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
691 rewriter.replaceOp(extractOp, extracted);
692
693 return success();
694 }
695 };
696
697 /// Conversion pattern that turns a vector.fma on a 1-D vector
698 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
699 /// This does not match vectors of n >= 2 rank.
700 ///
701 /// Example:
702 /// ```
703 /// vector.fma %a, %a, %a : vector<8xf32>
704 /// ```
705 /// is converted to:
706 /// ```
707 /// llvm.intr.fmuladd %va, %va, %va:
708 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
709 /// -> !llvm."<8 x f32>">
710 /// ```
711 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
712 public:
713 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
714
715 LogicalResult
matchAndRewrite(vector::FMAOp fmaOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const716 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
717 ConversionPatternRewriter &rewriter) const override {
718 VectorType vType = fmaOp.getVectorType();
719 if (vType.getRank() != 1)
720 return failure();
721 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
722 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
723 return success();
724 }
725 };
726
727 class VectorInsertElementOpConversion
728 : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
729 public:
730 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
731
732 LogicalResult
matchAndRewrite(vector::InsertElementOp insertEltOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const733 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
734 ConversionPatternRewriter &rewriter) const override {
735 auto vectorType = insertEltOp.getDestVectorType();
736 auto llvmType = typeConverter->convertType(vectorType);
737
738 // Bail if result type cannot be lowered.
739 if (!llvmType)
740 return failure();
741
742 if (vectorType.getRank() == 0) {
743 Location loc = insertEltOp.getLoc();
744 auto idxType = rewriter.getIndexType();
745 auto zero = rewriter.create<LLVM::ConstantOp>(
746 loc, typeConverter->convertType(idxType),
747 rewriter.getIntegerAttr(idxType, 0));
748 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
749 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
750 return success();
751 }
752
753 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
754 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
755 adaptor.getPosition());
756 return success();
757 }
758 };
759
760 class VectorInsertOpConversion
761 : public ConvertOpToLLVMPattern<vector::InsertOp> {
762 public:
763 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
764
765 LogicalResult
matchAndRewrite(vector::InsertOp insertOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const766 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
767 ConversionPatternRewriter &rewriter) const override {
768 auto loc = insertOp->getLoc();
769 auto sourceType = insertOp.getSourceType();
770 auto destVectorType = insertOp.getDestVectorType();
771 auto llvmResultType = typeConverter->convertType(destVectorType);
772 auto positionArrayAttr = insertOp.getPosition();
773
774 // Bail if result type cannot be lowered.
775 if (!llvmResultType)
776 return failure();
777
778 // Overwrite entire vector with value. Should be handled by folder, but
779 // just to be safe.
780 if (positionArrayAttr.empty()) {
781 rewriter.replaceOp(insertOp, adaptor.getSource());
782 return success();
783 }
784
785 // One-shot insertion of a vector into an array (only requires insertvalue).
786 if (sourceType.isa<VectorType>()) {
787 Value inserted = rewriter.create<LLVM::InsertValueOp>(
788 loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
789 positionArrayAttr);
790 rewriter.replaceOp(insertOp, inserted);
791 return success();
792 }
793
794 // Potential extraction of 1-D vector from array.
795 auto *context = insertOp->getContext();
796 Value extracted = adaptor.getDest();
797 auto positionAttrs = positionArrayAttr.getValue();
798 auto position = positionAttrs.back().cast<IntegerAttr>();
799 auto oneDVectorType = destVectorType;
800 if (positionAttrs.size() > 1) {
801 oneDVectorType = reducedVectorTypeBack(destVectorType);
802 auto nMinusOnePositionAttrs =
803 ArrayAttr::get(context, positionAttrs.drop_back());
804 extracted = rewriter.create<LLVM::ExtractValueOp>(
805 loc, typeConverter->convertType(oneDVectorType), extracted,
806 nMinusOnePositionAttrs);
807 }
808
809 // Insertion of an element into a 1-D LLVM vector.
810 auto i64Type = IntegerType::get(rewriter.getContext(), 64);
811 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
812 Value inserted = rewriter.create<LLVM::InsertElementOp>(
813 loc, typeConverter->convertType(oneDVectorType), extracted,
814 adaptor.getSource(), constant);
815
816 // Potential insertion of resulting 1-D vector into array.
817 if (positionAttrs.size() > 1) {
818 auto nMinusOnePositionAttrs =
819 ArrayAttr::get(context, positionAttrs.drop_back());
820 inserted = rewriter.create<LLVM::InsertValueOp>(
821 loc, llvmResultType, adaptor.getDest(), inserted,
822 nMinusOnePositionAttrs);
823 }
824
825 rewriter.replaceOp(insertOp, inserted);
826 return success();
827 }
828 };
829
830 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
831 ///
832 /// Example:
833 /// ```
834 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
835 /// ```
836 /// is rewritten into:
837 /// ```
838 /// %r = splat %f0: vector<2x4xf32>
839 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
840 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
841 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
842 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
843 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
844 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
845 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
846 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
847 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
848 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
849 /// // %r3 holds the final value.
850 /// ```
851 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
852 public:
853 using OpRewritePattern<FMAOp>::OpRewritePattern;
854
initialize()855 void initialize() {
856 // This pattern recursively unpacks one dimension at a time. The recursion
857 // bounded as the rank is strictly decreasing.
858 setHasBoundedRewriteRecursion();
859 }
860
matchAndRewrite(FMAOp op,PatternRewriter & rewriter) const861 LogicalResult matchAndRewrite(FMAOp op,
862 PatternRewriter &rewriter) const override {
863 auto vType = op.getVectorType();
864 if (vType.getRank() < 2)
865 return failure();
866
867 auto loc = op.getLoc();
868 auto elemType = vType.getElementType();
869 Value zero = rewriter.create<arith::ConstantOp>(
870 loc, elemType, rewriter.getZeroAttr(elemType));
871 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
872 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
873 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
874 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
875 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
876 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
877 desc = rewriter.create<InsertOp>(loc, fma, desc, i);
878 }
879 rewriter.replaceOp(op, desc);
880 return success();
881 }
882 };
883
884 /// Returns the strides if the memory underlying `memRefType` has a contiguous
885 /// static layout.
886 static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType)887 computeContiguousStrides(MemRefType memRefType) {
888 int64_t offset;
889 SmallVector<int64_t, 4> strides;
890 if (failed(getStridesAndOffset(memRefType, strides, offset)))
891 return None;
892 if (!strides.empty() && strides.back() != 1)
893 return None;
894 // If no layout or identity layout, this is contiguous by definition.
895 if (memRefType.getLayout().isIdentity())
896 return strides;
897
898 // Otherwise, we must determine contiguity form shapes. This can only ever
899 // work in static cases because MemRefType is underspecified to represent
900 // contiguous dynamic shapes in other ways than with just empty/identity
901 // layout.
902 auto sizes = memRefType.getShape();
903 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
904 if (ShapedType::isDynamic(sizes[index + 1]) ||
905 ShapedType::isDynamicStrideOrOffset(strides[index]) ||
906 ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
907 return None;
908 if (strides[index] != strides[index + 1] * sizes[index + 1])
909 return None;
910 }
911 return strides;
912 }
913
914 class VectorTypeCastOpConversion
915 : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
916 public:
917 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
918
919 LogicalResult
matchAndRewrite(vector::TypeCastOp castOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const920 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
921 ConversionPatternRewriter &rewriter) const override {
922 auto loc = castOp->getLoc();
923 MemRefType sourceMemRefType =
924 castOp.getOperand().getType().cast<MemRefType>();
925 MemRefType targetMemRefType = castOp.getType();
926
927 // Only static shape casts supported atm.
928 if (!sourceMemRefType.hasStaticShape() ||
929 !targetMemRefType.hasStaticShape())
930 return failure();
931
932 auto llvmSourceDescriptorTy =
933 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
934 if (!llvmSourceDescriptorTy)
935 return failure();
936 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
937
938 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
939 .dyn_cast_or_null<LLVM::LLVMStructType>();
940 if (!llvmTargetDescriptorTy)
941 return failure();
942
943 // Only contiguous source buffers supported atm.
944 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
945 if (!sourceStrides)
946 return failure();
947 auto targetStrides = computeContiguousStrides(targetMemRefType);
948 if (!targetStrides)
949 return failure();
950 // Only support static strides for now, regardless of contiguity.
951 if (llvm::any_of(*targetStrides, ShapedType::isDynamicStrideOrOffset))
952 return failure();
953
954 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
955
956 // Create descriptor.
957 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
958 Type llvmTargetElementTy = desc.getElementPtrType();
959 // Set allocated ptr.
960 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
961 allocated =
962 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
963 desc.setAllocatedPtr(rewriter, loc, allocated);
964 // Set aligned ptr.
965 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
966 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
967 desc.setAlignedPtr(rewriter, loc, ptr);
968 // Fill offset 0.
969 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
970 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
971 desc.setOffset(rewriter, loc, zero);
972
973 // Fill size and stride descriptors in memref.
974 for (const auto &indexedSize :
975 llvm::enumerate(targetMemRefType.getShape())) {
976 int64_t index = indexedSize.index();
977 auto sizeAttr =
978 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
979 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
980 desc.setSize(rewriter, loc, index, size);
981 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
982 (*targetStrides)[index]);
983 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
984 desc.setStride(rewriter, loc, index, stride);
985 }
986
987 rewriter.replaceOp(castOp, {desc});
988 return success();
989 }
990 };
991
992 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
993 /// Non-scalable versions of this operation are handled in Vector Transforms.
994 class VectorCreateMaskOpRewritePattern
995 : public OpRewritePattern<vector::CreateMaskOp> {
996 public:
VectorCreateMaskOpRewritePattern(MLIRContext * context,bool enableIndexOpt)997 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
998 bool enableIndexOpt)
999 : OpRewritePattern<vector::CreateMaskOp>(context),
1000 force32BitVectorIndices(enableIndexOpt) {}
1001
matchAndRewrite(vector::CreateMaskOp op,PatternRewriter & rewriter) const1002 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1003 PatternRewriter &rewriter) const override {
1004 auto dstType = op.getType();
1005 if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
1006 return failure();
1007 IntegerType idxType =
1008 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1009 auto loc = op->getLoc();
1010 Value indices = rewriter.create<LLVM::StepVectorOp>(
1011 loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1012 /*isScalable=*/true));
1013 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1014 op.getOperand(0));
1015 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1016 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1017 indices, bounds);
1018 rewriter.replaceOp(op, comp);
1019 return success();
1020 }
1021
1022 private:
1023 const bool force32BitVectorIndices;
1024 };
1025
1026 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1027 public:
1028 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1029
1030 // Proof-of-concept lowering implementation that relies on a small
1031 // runtime support library, which only needs to provide a few
1032 // printing methods (single value for all data types, opening/closing
1033 // bracket, comma, newline). The lowering fully unrolls a vector
1034 // in terms of these elementary printing operations. The advantage
1035 // of this approach is that the library can remain unaware of all
1036 // low-level implementation details of vectors while still supporting
1037 // output of any shaped and dimensioned vector. Due to full unrolling,
1038 // this approach is less suited for very large vectors though.
1039 //
1040 // TODO: rely solely on libc in future? something else?
1041 //
1042 LogicalResult
matchAndRewrite(vector::PrintOp printOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1043 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1044 ConversionPatternRewriter &rewriter) const override {
1045 Type printType = printOp.getPrintType();
1046
1047 if (typeConverter->convertType(printType) == nullptr)
1048 return failure();
1049
1050 // Make sure element type has runtime support.
1051 PrintConversion conversion = PrintConversion::None;
1052 VectorType vectorType = printType.dyn_cast<VectorType>();
1053 Type eltType = vectorType ? vectorType.getElementType() : printType;
1054 Operation *printer;
1055 if (eltType.isF32()) {
1056 printer =
1057 LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
1058 } else if (eltType.isF64()) {
1059 printer =
1060 LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
1061 } else if (eltType.isIndex()) {
1062 printer =
1063 LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
1064 } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1065 // Integers need a zero or sign extension on the operand
1066 // (depending on the source type) as well as a signed or
1067 // unsigned print method. Up to 64-bit is supported.
1068 unsigned width = intTy.getWidth();
1069 if (intTy.isUnsigned()) {
1070 if (width <= 64) {
1071 if (width < 64)
1072 conversion = PrintConversion::ZeroExt64;
1073 printer = LLVM::lookupOrCreatePrintU64Fn(
1074 printOp->getParentOfType<ModuleOp>());
1075 } else {
1076 return failure();
1077 }
1078 } else {
1079 assert(intTy.isSignless() || intTy.isSigned());
1080 if (width <= 64) {
1081 // Note that we *always* zero extend booleans (1-bit integers),
1082 // so that true/false is printed as 1/0 rather than -1/0.
1083 if (width == 1)
1084 conversion = PrintConversion::ZeroExt64;
1085 else if (width < 64)
1086 conversion = PrintConversion::SignExt64;
1087 printer = LLVM::lookupOrCreatePrintI64Fn(
1088 printOp->getParentOfType<ModuleOp>());
1089 } else {
1090 return failure();
1091 }
1092 }
1093 } else {
1094 return failure();
1095 }
1096
1097 // Unroll vector into elementary print calls.
1098 int64_t rank = vectorType ? vectorType.getRank() : 0;
1099 Type type = vectorType ? vectorType : eltType;
1100 emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
1101 conversion);
1102 emitCall(rewriter, printOp->getLoc(),
1103 LLVM::lookupOrCreatePrintNewlineFn(
1104 printOp->getParentOfType<ModuleOp>()));
1105 rewriter.eraseOp(printOp);
1106 return success();
1107 }
1108
1109 private:
1110 enum class PrintConversion {
1111 // clang-format off
1112 None,
1113 ZeroExt64,
1114 SignExt64
1115 // clang-format on
1116 };
1117
emitRanks(ConversionPatternRewriter & rewriter,Operation * op,Value value,Type type,Operation * printer,int64_t rank,PrintConversion conversion) const1118 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1119 Value value, Type type, Operation *printer, int64_t rank,
1120 PrintConversion conversion) const {
1121 VectorType vectorType = type.dyn_cast<VectorType>();
1122 Location loc = op->getLoc();
1123 if (!vectorType) {
1124 assert(rank == 0 && "The scalar case expects rank == 0");
1125 switch (conversion) {
1126 case PrintConversion::ZeroExt64:
1127 value = rewriter.create<arith::ExtUIOp>(
1128 loc, IntegerType::get(rewriter.getContext(), 64), value);
1129 break;
1130 case PrintConversion::SignExt64:
1131 value = rewriter.create<arith::ExtSIOp>(
1132 loc, IntegerType::get(rewriter.getContext(), 64), value);
1133 break;
1134 case PrintConversion::None:
1135 break;
1136 }
1137 emitCall(rewriter, loc, printer, value);
1138 return;
1139 }
1140
1141 emitCall(rewriter, loc,
1142 LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1143 Operation *printComma =
1144 LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1145
1146 if (rank <= 1) {
1147 auto reducedType = vectorType.getElementType();
1148 auto llvmType = typeConverter->convertType(reducedType);
1149 int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1150 for (int64_t d = 0; d < dim; ++d) {
1151 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1152 llvmType, /*rank=*/0, /*pos=*/d);
1153 emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1154 conversion);
1155 if (d != dim - 1)
1156 emitCall(rewriter, loc, printComma);
1157 }
1158 emitCall(
1159 rewriter, loc,
1160 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1161 return;
1162 }
1163
1164 int64_t dim = vectorType.getDimSize(0);
1165 for (int64_t d = 0; d < dim; ++d) {
1166 auto reducedType = reducedVectorTypeFront(vectorType);
1167 auto llvmType = typeConverter->convertType(reducedType);
1168 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1169 llvmType, rank, d);
1170 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1171 conversion);
1172 if (d != dim - 1)
1173 emitCall(rewriter, loc, printComma);
1174 }
1175 emitCall(rewriter, loc,
1176 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1177 }
1178
1179 // Helper to emit a call.
emitCall(ConversionPatternRewriter & rewriter,Location loc,Operation * ref,ValueRange params=ValueRange ())1180 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1181 Operation *ref, ValueRange params = ValueRange()) {
1182 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1183 params);
1184 }
1185 };
1186
1187 /// The Splat operation is lowered to an insertelement + a shufflevector
1188 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1189 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1190 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1191
1192 LogicalResult
matchAndRewrite__anon9084af800111::VectorSplatOpLowering1193 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1194 ConversionPatternRewriter &rewriter) const override {
1195 VectorType resultType = splatOp.getType().cast<VectorType>();
1196 if (resultType.getRank() > 1)
1197 return failure();
1198
1199 // First insert it into an undef vector so we can shuffle it.
1200 auto vectorType = typeConverter->convertType(splatOp.getType());
1201 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1202 auto zero = rewriter.create<LLVM::ConstantOp>(
1203 splatOp.getLoc(),
1204 typeConverter->convertType(rewriter.getIntegerType(32)),
1205 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1206
1207 // For 0-d vector, we simply do `insertelement`.
1208 if (resultType.getRank() == 0) {
1209 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1210 splatOp, vectorType, undef, adaptor.getInput(), zero);
1211 return success();
1212 }
1213
1214 // For 1-d vector, we additionally do a `vectorshuffle`.
1215 auto v = rewriter.create<LLVM::InsertElementOp>(
1216 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1217
1218 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1219 SmallVector<int32_t, 4> zeroValues(width, 0);
1220
1221 // Shuffle the value across the desired number of elements.
1222 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1223 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1224 zeroAttrs);
1225 return success();
1226 }
1227 };
1228
1229 /// The Splat operation is lowered to an insertelement + a shufflevector
1230 /// operation. Splat to only 2+-d vector result types are lowered by the
1231 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1232 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1233 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1234
1235 LogicalResult
matchAndRewrite__anon9084af800111::VectorSplatNdOpLowering1236 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1237 ConversionPatternRewriter &rewriter) const override {
1238 VectorType resultType = splatOp.getType();
1239 if (resultType.getRank() <= 1)
1240 return failure();
1241
1242 // First insert it into an undef vector so we can shuffle it.
1243 auto loc = splatOp.getLoc();
1244 auto vectorTypeInfo =
1245 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1246 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1247 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1248 if (!llvmNDVectorTy || !llvm1DVectorTy)
1249 return failure();
1250
1251 // Construct returned value.
1252 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1253
1254 // Construct a 1-D vector with the splatted value that we insert in all the
1255 // places within the returned descriptor.
1256 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1257 auto zero = rewriter.create<LLVM::ConstantOp>(
1258 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1259 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1260 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1261 adaptor.getInput(), zero);
1262
1263 // Shuffle the value across the desired number of elements.
1264 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1265 SmallVector<int32_t, 4> zeroValues(width, 0);
1266 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1267 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
1268
1269 // Iterate of linear index, convert to coords space and insert splatted 1-D
1270 // vector in each position.
1271 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
1272 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
1273 position);
1274 });
1275 rewriter.replaceOp(splatOp, desc);
1276 return success();
1277 }
1278 };
1279
1280 } // namespace
1281
1282 /// Populate the given list with patterns that convert from Vector to LLVM.
populateVectorToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns,bool reassociateFPReductions,bool force32BitVectorIndices)1283 void mlir::populateVectorToLLVMConversionPatterns(
1284 LLVMTypeConverter &converter, RewritePatternSet &patterns,
1285 bool reassociateFPReductions, bool force32BitVectorIndices) {
1286 MLIRContext *ctx = converter.getDialect()->getContext();
1287 patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1288 populateVectorInsertExtractStridedSliceTransforms(patterns);
1289 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1290 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1291 patterns
1292 .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1293 VectorExtractElementOpConversion, VectorExtractOpConversion,
1294 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1295 VectorInsertOpConversion, VectorPrintOpConversion,
1296 VectorTypeCastOpConversion, VectorScaleOpConversion,
1297 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1298 VectorLoadStoreConversion<vector::MaskedLoadOp,
1299 vector::MaskedLoadOpAdaptor>,
1300 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1301 VectorLoadStoreConversion<vector::MaskedStoreOp,
1302 vector::MaskedStoreOpAdaptor>,
1303 VectorGatherOpConversion, VectorScatterOpConversion,
1304 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1305 VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
1306 // Transfer ops with rank > 1 are handled by VectorToSCF.
1307 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1308 }
1309
populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)1310 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1311 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1312 patterns.add<VectorMatmulOpConversion>(converter);
1313 patterns.add<VectorFlatTransposeOpConversion>(converter);
1314 }
1315