1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24
25 using namespace mlir;
26
27 namespace {
28
29 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
30 return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
31 }
32
33 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
34 AllocOpLowering(LLVMTypeConverter &converter)
35 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
36 converter) {}
37
38 LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
39 bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
40
41 if (useGenericFn)
42 return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
43
44 return LLVM::lookupOrCreateMallocFn(module, getIndexType());
45 }
46
47 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
48 Location loc, Value sizeBytes,
49 Operation *op) const override {
50 // Heap allocations.
51 memref::AllocOp allocOp = cast<memref::AllocOp>(op);
52 MemRefType memRefType = allocOp.getType();
53
54 Value alignment;
55 if (auto alignmentAttr = allocOp.getAlignment()) {
56 alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
57 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
58 // In the case where no alignment is specified, we may want to override
59 // `malloc`'s behavior. `malloc` typically aligns at the size of the
60 // biggest scalar type on the target HW. For non-scalar element types, use
61 // the natural alignment of the LLVM type given by the LLVM DataLayout.
62 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
63 }
64
65 if (alignment) {
66 // Adjust the allocation size to consider alignment.
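// Padding the size by a full `alignment` guarantees that rounding the
// malloc'ed address up to the next multiple of `alignment` still leaves the
// originally requested number of bytes usable from the aligned pointer.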
67 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
68 }
69
70 // Allocate the underlying buffer and store a pointer to it in the MemRef
71 // descriptor.
72 Type elementPtrType = this->getElementPtrType(memRefType);
73 auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
74 auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
75 getVoidPtrType());
76 Value allocatedPtr =
77 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
78
79 Value alignedPtr = allocatedPtr;
80 if (alignment) {
81 // Compute the aligned type pointer.
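// For example, with a 16-byte alignment and a malloc result whose address
// ends in 0x08, the aligned pointer is bumped to the next address ending in
// 0x10; the original allocated pointer is the one later passed to free().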
82 Value allocatedInt =
83 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
84 Value alignmentInt =
85 createAligned(rewriter, loc, allocatedInt, alignment);
86 alignedPtr =
87 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
88 }
89
90 return std::make_tuple(allocatedPtr, alignedPtr);
91 }
92 };
93
94 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
95 AlignedAllocOpLowering(LLVMTypeConverter &converter)
96 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
97 converter) {}
98
99 /// Returns the memref's element size in bytes using the data layout active at
100 /// `op`.
101 // TODO: there are other places where this is used. Expose publicly?
102 unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
103 const DataLayout *layout = &defaultLayout;
104 if (const DataLayoutAnalysis *analysis =
105 getTypeConverter()->getDataLayoutAnalysis()) {
106 layout = &analysis->getAbove(op);
107 }
108 Type elementType = memRefType.getElementType();
109 if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
110 return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
111 *layout);
112 if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
113 return getTypeConverter()->getUnrankedMemRefDescriptorSize(
114 memRefElementType, *layout);
115 return layout->getTypeSize(elementType);
116 }
117
118 /// Returns true if the memref size in bytes is known to be a multiple of
119 /// factor assuming the data layout active at `op`.
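/// For example, memref<4x?xf32> yields a static byte-size divisor of
/// 4 * 4 = 16 (dynamic dimensions are skipped), so this returns true for a
/// factor of 16 but false for a factor of 32.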
120 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
121 Operation *op) const {
122 uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
123 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
124 if (ShapedType::isDynamic(type.getDimSize(i)))
125 continue;
126 sizeDivisor = sizeDivisor * type.getDimSize(i);
127 }
128 return sizeDivisor % factor == 0;
129 }
130
131 /// Returns the alignment to be used for the allocation call itself.
132 /// aligned_alloc requires the alignment to be a power of two and the
133 /// allocation size to be a multiple of the alignment.
134 int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
135 if (Optional<uint64_t> alignment = allocOp.getAlignment())
136 return *alignment;
137
138 // Whenever we don't have alignment set, we will use an alignment
139 // consistent with the element type; since the alignment has to be a
140 // power of two, we bump it up to the next power of two if it isn't one already.
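// For example, an f32 element (4 bytes) yields max(16, 4) = 16, while a
// 128-byte vector element, already a power of two, yields 128.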
141 auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
142 return std::max(kMinAlignedAllocAlignment,
143 llvm::PowerOf2Ceil(eltSizeBytes));
144 }
145
146 LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
147 bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
148
149 if (useGenericFn)
150 return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
151
152 return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
153 }
154
155 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
156 Location loc, Value sizeBytes,
157 Operation *op) const override {
158 // Heap allocations.
159 memref::AllocOp allocOp = cast<memref::AllocOp>(op);
160 MemRefType memRefType = allocOp.getType();
161 int64_t alignment = getAllocationAlignment(allocOp);
162 Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
163
164 // aligned_alloc requires size to be a multiple of alignment; we will pad
165 // the size to the next multiple if necessary.
166 if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
167 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
168
169 Type elementPtrType = this->getElementPtrType(memRefType);
170 auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
171 auto results =
172 createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
173 getVoidPtrType());
174 Value allocatedPtr =
175 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
176
177 return std::make_tuple(allocatedPtr, allocatedPtr);
178 }
179
180 /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
181 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
182
183 /// Default layout to use in absence of the corresponding analysis.
184 DataLayout defaultLayout;
185 };
186
187 // Out-of-line definition, required until C++17.
188 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
189
190 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
191 AllocaOpLowering(LLVMTypeConverter &converter)
192 : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
193 converter) {}
194
195 /// Allocates the underlying buffer on the stack with an llvm.alloca. The
196 /// requested alignment, if any, is passed directly to the alloca, so the
197 /// allocated and aligned pointers returned are the same value.
198 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
199 Location loc, Value sizeBytes,
200 Operation *op) const override {
201
202 // With alloca, one gets a pointer to the element type right away: this is
203 // a stack allocation, and no separate realignment of the pointer is needed.
204 auto allocaOp = cast<memref::AllocaOp>(op);
205 auto elementPtrType = this->getElementPtrType(allocaOp.getType());
206
207 auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
208 loc, elementPtrType, sizeBytes, allocaOp.getAlignment().value_or(0));
209
210 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
211 }
212 };
213
214 struct AllocaScopeOpLowering
215 : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
216 using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
217
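/// Conceptually, `memref.alloca_scope { ... }` is rewritten to roughly the
/// following CFG (a sketch, omitting type details):
///   %sp = llvm.intr.stacksave
///   llvm.br ^body
/// ^body:
///   <inlined region contents>
///   llvm.intr.stackrestore %sp
///   llvm.br ^continue(<results>)
/// ^continue(...):
///   <code after the alloca_scope>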
218 LogicalResult
219 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
220 ConversionPatternRewriter &rewriter) const override {
221 OpBuilder::InsertionGuard guard(rewriter);
222 Location loc = allocaScopeOp.getLoc();
223
224 // Split the current block before the AllocaScopeOp to create the inlining
225 // point.
226 auto *currentBlock = rewriter.getInsertionBlock();
227 auto *remainingOpsBlock =
228 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
229 Block *continueBlock;
230 if (allocaScopeOp.getNumResults() == 0) {
231 continueBlock = remainingOpsBlock;
232 } else {
233 continueBlock = rewriter.createBlock(
234 remainingOpsBlock, allocaScopeOp.getResultTypes(),
235 SmallVector<Location>(allocaScopeOp->getNumResults(),
236 allocaScopeOp.getLoc()));
237 rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
238 }
239
240 // Inline body region.
241 Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
242 Block *afterBody = &allocaScopeOp.getBodyRegion().back();
243 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
244
245 // Save stack and then branch into the body of the region.
246 rewriter.setInsertionPointToEnd(currentBlock);
247 auto stackSaveOp =
248 rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
249 rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
250
251 // Replace the alloca_scope return with a branch that jumps out of the body.
252 // The stack is restored before leaving the body region.
253 rewriter.setInsertionPointToEnd(afterBody);
254 auto returnOp =
255 cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
256 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
257 returnOp, returnOp.getResults(), continueBlock);
258
259 // Insert stack restore before jumping out of the body of the region.
260 rewriter.setInsertionPoint(branchOp);
261 rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
262
263 // Replace the op with the values returned from the body region.
264 rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
265
266 return success();
267 }
268 };
269
270 struct AssumeAlignmentOpLowering
271 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
272 using ConvertOpToLLVMPattern<
273 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
274
275 LogicalResult
276 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
277 ConversionPatternRewriter &rewriter) const override {
278 Value memref = adaptor.getMemref();
279 unsigned alignment = op.getAlignment();
280 auto loc = op.getLoc();
281
282 MemRefDescriptor memRefDescriptor(memref);
283 Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
284
285 // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
286 // the asserted memref.alignedPtr isn't used anywhere else, as the real
287 // users like load/store/views always re-extract memref.alignedPtr as they
288 // get lowered.
289 //
290 // This relies on LLVM's CSE optimization (potentially after SROA), since
291 // after CSE all memref.alignedPtr instances get de-duplicated into the same
292 // pointer SSA value.
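// Schematically, for an f32 memref (shown only for illustration), a 64-bit
// index type and an alignment of 16, this emits roughly:
//   %p  = llvm.ptrtoint %alignedPtr : !llvm.ptr<f32> to i64
//   %r  = llvm.and %p, %c15
//   %ok = llvm.icmp "eq" %r, %c0
//   llvm.intr.assume %ok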
293 auto intPtrType =
294 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
295 Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
296 Value mask =
297 createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
298 Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
299 rewriter.create<LLVM::AssumeOp>(
300 loc, rewriter.create<LLVM::ICmpOp>(
301 loc, LLVM::ICmpPredicate::eq,
302 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
303
304 rewriter.eraseOp(op);
305 return success();
306 }
307 };
308
309 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
310 // The memref descriptor being an SSA value, there is no need to clean it up
311 // in any way.
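// For example, `memref.dealloc %m : memref<?xf32>` becomes, roughly:
//   %alloc = llvm.extractvalue %desc[0]
//   %void  = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()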
312 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
313 using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
314
315 explicit DeallocOpLowering(LLVMTypeConverter &converter)
316 : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
317
318 LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
319 bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
320
321 if (useGenericFn)
322 return LLVM::lookupOrCreateGenericFreeFn(module);
323
324 return LLVM::lookupOrCreateFreeFn(module);
325 }
326
327 LogicalResult
328 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter) const override {
330 // Insert the `free` declaration if it is not already present.
331 auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
332 MemRefDescriptor memref(adaptor.getMemref());
333 Value casted = rewriter.create<LLVM::BitcastOp>(
334 op.getLoc(), getVoidPtrType(),
335 memref.allocatedPtr(rewriter, op.getLoc()));
336 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
337 op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
338 return success();
339 }
340 };
341
342 // A `dim` is converted to a constant for static sizes and to an access to the
343 // size stored in the memref descriptor for dynamic sizes.
344 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
345 using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
346
347 LogicalResult
348 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override {
350 Type operandType = dimOp.getSource().getType();
351 if (operandType.isa<UnrankedMemRefType>()) {
352 rewriter.replaceOp(
353 dimOp, {extractSizeOfUnrankedMemRef(
354 operandType, dimOp, adaptor.getOperands(), rewriter)});
355
356 return success();
357 }
358 if (operandType.isa<MemRefType>()) {
359 rewriter.replaceOp(
360 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
361 adaptor.getOperands(), rewriter)});
362 return success();
363 }
364 llvm_unreachable("expected MemRefType or UnrankedMemRefType");
365 }
366
367 private:
368 Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
369 OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter) const {
371 Location loc = dimOp.getLoc();
372
373 auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
374 auto scalarMemRefType =
375 MemRefType::get({}, unrankedMemRefType.getElementType());
376 unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
377
378 // Extract pointer to the underlying ranked descriptor and bitcast it to a
379 // memref<element_type> descriptor pointer to minimize the number of GEP
380 // operations.
381 UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
382 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
383 Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
384 loc,
385 LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
386 addressSpace),
387 underlyingRankedDesc);
388
389 // Get pointer to offset field of memref<element_type> descriptor.
390 Type indexPtrTy = LLVM::LLVMPointerType::get(
391 getTypeConverter()->getIndexType(), addressSpace);
392 Value two = rewriter.create<LLVM::ConstantOp>(
393 loc, typeConverter->convertType(rewriter.getI32Type()),
394 rewriter.getI32IntegerAttr(2));
395 Value offsetPtr = rewriter.create<LLVM::GEPOp>(
396 loc, indexPtrTy, scalarMemRefDescPtr,
397 ValueRange({createIndexConstant(rewriter, loc, 0), two}));
398
399 // The size value that we have to extract can be obtained using a GEPOp with
400 // a `dimOp.index() + 1` index argument.
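// This works because the converted descriptor lays out the offset (field 2)
// immediately followed by `rank` sizes and then `rank` strides, all of index
// type, so the size of dimension d lives at offsetPtr[d + 1].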
401 Value idxPlusOne = rewriter.create<LLVM::AddOp>(
402 loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
403 Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
404 ValueRange({idxPlusOne}));
405 return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
406 }
407
408 Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
409 if (Optional<int64_t> idx = dimOp.getConstantIndex())
410 return idx;
411
412 if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
413 return constantOp.getValue()
414 .cast<IntegerAttr>()
415 .getValue()
416 .getSExtValue();
417
418 return llvm::None;
419 }
420
421 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
422 OpAdaptor adaptor,
423 ConversionPatternRewriter &rewriter) const {
424 Location loc = dimOp.getLoc();
425
426 // Take advantage of a constant index when possible.
427 MemRefType memRefType = operandType.cast<MemRefType>();
428 if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
429 int64_t i = *index;
430 if (memRefType.isDynamicDim(i)) {
431 // extract dynamic size from the memref descriptor.
432 MemRefDescriptor descriptor(adaptor.getSource());
433 return descriptor.size(rewriter, loc, i);
434 }
435 // Use constant for static size.
436 int64_t dimSize = memRefType.getDimSize(i);
437 return createIndexConstant(rewriter, loc, dimSize);
438 }
439 Value index = adaptor.getIndex();
440 int64_t rank = memRefType.getRank();
441 MemRefDescriptor memrefDescriptor(adaptor.getSource());
442 return memrefDescriptor.size(rewriter, loc, index, rank);
443 }
444 };
445
446 /// Common base for load and store operations on MemRefs. Restricts the match
447 /// to supported MemRef types. Provides functionality to emit code accessing a
448 /// specific element of the underlying data buffer.
449 template <typename Derived>
450 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
451 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
452 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
453 using Base = LoadStoreOpLowering<Derived>;
454
455 LogicalResult match(Derived op) const override {
456 MemRefType type = op.getMemRefType();
457 return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
458 }
459 };
460
461 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
462 /// retried until it succeeds in atomically storing a new value into memory.
463 ///
464 /// +---------------------------------+
465 /// | <code before the AtomicRMWOp> |
466 /// | <compute initial %loaded> |
467 /// | cf.br loop(%loaded) |
468 /// +---------------------------------+
469 /// |
470 /// -------| |
471 /// | v v
472 /// | +--------------------------------+
473 /// | | loop(%loaded): |
474 /// | | <body contents> |
475 /// | | %pair = cmpxchg |
476 /// | | %ok = %pair[0] |
477 /// | | %new = %pair[1] |
478 /// | | cf.cond_br %ok, end, loop(%new) |
479 /// | +--------------------------------+
480 /// | | |
481 /// |----------- |
482 /// v
483 /// +--------------------------------+
484 /// | end: |
485 /// | <code after the AtomicRMWOp> |
486 /// +--------------------------------+
487 ///
488 struct GenericAtomicRMWOpLowering
489 : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
490 using Base::Base;
491
492 LogicalResult
493 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter) const override {
495 auto loc = atomicOp.getLoc();
496 Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
497
498 // Split the block into initial, loop, and ending parts.
499 auto *initBlock = rewriter.getInsertionBlock();
500 auto *loopBlock = rewriter.createBlock(
501 initBlock->getParent(), std::next(Region::iterator(initBlock)),
502 valueType, loc);
503 auto *endBlock = rewriter.createBlock(
504 loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
505
506 // Operations range to be moved to `endBlock`.
507 auto opsToMoveStart = atomicOp->getIterator();
508 auto opsToMoveEnd = initBlock->back().getIterator();
509
510 // Compute the loaded value and branch to the loop block.
511 rewriter.setInsertionPointToEnd(initBlock);
512 auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
513 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
514 adaptor.getIndices(), rewriter);
515 Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
516 rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
517
518 // Prepare the body of the loop block.
519 rewriter.setInsertionPointToStart(loopBlock);
520
521 // Clone the GenericAtomicRMWOp region and extract the result.
522 auto loopArgument = loopBlock->getArgument(0);
523 BlockAndValueMapping mapping;
524 mapping.map(atomicOp.getCurrentValue(), loopArgument);
525 Block &entryBlock = atomicOp.body().front();
526 for (auto &nestedOp : entryBlock.without_terminator()) {
527 Operation *clone = rewriter.clone(nestedOp, mapping);
528 mapping.map(nestedOp.getResults(), clone->getResults());
529 }
530 Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
531
532 // Prepare the epilog of the loop block.
533 // Append the cmpxchg op to the end of the loop block.
534 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
535 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
536 auto boolType = IntegerType::get(rewriter.getContext(), 1);
537 auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
538 {valueType, boolType});
539 auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
540 loc, pairType, dataPtr, loopArgument, result, successOrdering,
541 failureOrdering);
542 // Extract the %new_loaded and %ok values from the pair.
543 Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
544 loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
545 Value ok = rewriter.create<LLVM::ExtractValueOp>(
546 loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
547
548 // Conditionally branch to the end or back to the loop depending on %ok.
549 rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
550 loopBlock, newLoaded);
551
552 rewriter.setInsertionPointToEnd(endBlock);
553 moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
554 std::next(opsToMoveEnd), rewriter);
555
556 // The 'result' of the atomic_rmw op is the newly loaded value.
557 rewriter.replaceOp(atomicOp, {newLoaded});
558
559 return success();
560 }
561
562 private:
563 // Clones a segment of ops [start, end) and erases the original.
564 void moveOpsRange(ValueRange oldResult, ValueRange newResult,
565 Block::iterator start, Block::iterator end,
566 ConversionPatternRewriter &rewriter) const {
567 BlockAndValueMapping mapping;
568 mapping.map(oldResult, newResult);
569 SmallVector<Operation *, 2> opsToErase;
570 for (auto it = start; it != end; ++it) {
571 rewriter.clone(*it, mapping);
572 opsToErase.push_back(&*it);
573 }
574 for (auto *it : opsToErase)
575 rewriter.eraseOp(it);
576 }
577 };
578
579 /// Returns the LLVM type of the global variable given the memref type `type`.
580 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
581 LLVMTypeConverter &typeConverter) {
582 // The LLVM type for a global memref will be a multi-dimensional array. For
583 // declarations or uninitialized global memrefs, we could potentially flatten
584 // this to a 1-D array. However, for memref.global's with an initial value,
585 // we do not intend to flatten the ElementsAttribute when going to the LLVM
586 // dialect, so the LLVM type needs to be a multi-dimensional array.
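// For example, memref<2x3xf32> converts to !llvm.array<2 x array<3 x f32>>.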
587 Type elementType = typeConverter.convertType(type.getElementType());
588 Type arrayTy = elementType;
589 // Shape has the outermost dim at index 0, so we need to walk it backwards.
590 for (int64_t dim : llvm::reverse(type.getShape()))
591 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
592 return arrayTy;
593 }
594
595 /// GlobalMemrefOp is lowered to an LLVM global variable.
596 struct GlobalMemrefOpLowering
597 : public ConvertOpToLLVMPattern<memref::GlobalOp> {
598 using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
599
600 LogicalResult
601 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
602 ConversionPatternRewriter &rewriter) const override {
603 MemRefType type = global.getType();
604 if (!isConvertibleAndHasIdentityMaps(type))
605 return failure();
606
607 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
608
609 LLVM::Linkage linkage =
610 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
611
612 Attribute initialValue = nullptr;
613 if (!global.isExternal() && !global.isUninitialized()) {
614 auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
615 initialValue = elementsAttr;
616
617 // For scalar memrefs, the global variable created is of the element type,
618 // so unpack the elements attribute to extract the value.
619 if (type.getRank() == 0)
620 initialValue = elementsAttr.getSplatValue<Attribute>();
621 }
622
623 uint64_t alignment = global.getAlignment().value_or(0);
624
625 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
626 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
627 initialValue, alignment, type.getMemorySpaceAsInt());
628 if (!global.isExternal() && global.isUninitialized()) {
629 Block *blk = new Block();
630 newGlobal.getInitializerRegion().push_back(blk);
631 rewriter.setInsertionPointToStart(blk);
632 Value undef[] = {
633 rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
634 rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
635 }
636 return success();
637 }
638 };
639
640 /// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
641 /// the first element stashed into the descriptor. This relies on
642 /// `AllocLikeOpLowering` for the MemRef descriptor construction.
643 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
644 GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
645 : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
646 converter) {}
647
648 /// Buffer "allocation" for the memref.get_global op consists of taking the
649 /// address of the referenced global variable.
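/// For a global of type memref<2x3xf32> named @g (a hypothetical name), this
/// emits roughly:
///   %addr = llvm.mlir.addressof @g : !llvm.ptr<array<2 x array<3 x f32>>>
///   %elem = llvm.getelementptr %addr[0, 0, 0] : ... -> !llvm.ptr<f32>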
650 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
651 Location loc, Value sizeBytes,
652 Operation *op) const override {
653 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
654 MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
655 unsigned memSpace = type.getMemorySpaceAsInt();
656
657 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
658 auto addressOf = rewriter.create<LLVM::AddressOfOp>(
659 loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
660 getGlobalOp.getName());
661
662 // Get the address of the first element in the array by creating a GEP with
663 // the address of the GV as the base, and (rank + 1) zero indices.
664 Type elementType = typeConverter->convertType(type.getElementType());
665 Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
666
667 SmallVector<Value> operands;
668 operands.insert(operands.end(), type.getRank() + 1,
669 createIndexConstant(rewriter, loc, 0));
670 auto gep =
671 rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
672
673 // We do not expect the memref obtained using `memref.get_global` to be
674 // ever deallocated. Set the allocated pointer to a known bad value to
675 // help debug if that ever happens.
676 auto intPtrType = getIntPtrType(memSpace);
677 Value deadBeefConst =
678 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
679 auto deadBeefPtr =
680 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
681
682 // Both allocated and aligned pointers are the same. We could potentially
683 // stash a nullptr for the allocated pointer since we do not expect any dealloc.
684 return std::make_tuple(deadBeefPtr, gep);
685 }
686 };
687
688 // Load operation is lowered to obtaining a pointer to the indexed element
689 // and loading it.
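// For example, `%v = memref.load %m[%i, %j] : memref<?x?xf32>` becomes,
// schematically:
//   %off  = %i * stride0 + %j * stride1   // from the descriptor
//   %addr = llvm.getelementptr %alignedPtr[%off]
//   %v    = llvm.load %addr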
690 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
691 using Base::Base;
692
693 LogicalResult
694 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
695 ConversionPatternRewriter &rewriter) const override {
696 auto type = loadOp.getMemRefType();
697
698 Value dataPtr =
699 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
700 adaptor.getIndices(), rewriter);
701 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
702 return success();
703 }
704 };
705
706 // Store operation is lowered to obtaining a pointer to the indexed element,
707 // and storing the given value to it.
708 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
709 using Base::Base;
710
711 LogicalResult
712 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
713 ConversionPatternRewriter &rewriter) const override {
714 auto type = op.getMemRefType();
715
716 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
717 adaptor.getIndices(), rewriter);
718 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
719 return success();
720 }
721 };
722
723 // The prefetch operation is lowered in a way similar to the load operation
724 // except that the llvm.prefetch operation is used for replacement.
725 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
726 using Base::Base;
727
728 LogicalResult
729 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
730 ConversionPatternRewriter &rewriter) const override {
731 auto type = prefetchOp.getMemRefType();
732 auto loc = prefetchOp.getLoc();
733
734 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
735 adaptor.getIndices(), rewriter);
736
737 // Replace with llvm.prefetch.
738 auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
739 auto isWrite = rewriter.create<LLVM::ConstantOp>(
740 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
741 auto localityHint = rewriter.create<LLVM::ConstantOp>(
742 loc, llvmI32Type,
743 rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
744 auto isData = rewriter.create<LLVM::ConstantOp>(
745 loc, llvmI32Type,
746 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
747
748 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
749 localityHint, isData);
750 return success();
751 }
752 };
753
754 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
755 using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
756
757 LogicalResult
758 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter) const override {
760 Location loc = op.getLoc();
761 Type operandType = op.getMemref().getType();
762 if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
763 UnrankedMemRefDescriptor desc(adaptor.getMemref());
764 rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
765 return success();
766 }
767 if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
768 rewriter.replaceOp(
769 op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
770 return success();
771 }
772 return failure();
773 }
774 };
775
776 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
777 using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
778
779 LogicalResult match(memref::CastOp memRefCastOp) const override {
780 Type srcType = memRefCastOp.getOperand().getType();
781 Type dstType = memRefCastOp.getType();
782
783 // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
784 // used for type erasure. For now it must preserve the underlying element
785 // type and requires the source and result types to have the same rank.
786 // Therefore, perform a sanity check that the underlying structs are the
787 // same. Once the op semantics are relaxed we can revisit this.
788 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
789 return success(typeConverter->convertType(srcType) ==
790 typeConverter->convertType(dstType));
791
792 // At least one of the operands has an unranked memref type.
793 assert(srcType.isa<UnrankedMemRefType>() ||
794 dstType.isa<UnrankedMemRefType>());
795
796 // Unranked-to-unranked casts are disallowed.
797 return !(srcType.isa<UnrankedMemRefType>() &&
798 dstType.isa<UnrankedMemRefType>())
799 ? success()
800 : failure();
801 }
802
803 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
804 ConversionPatternRewriter &rewriter) const override {
805 auto srcType = memRefCastOp.getOperand().getType();
806 auto dstType = memRefCastOp.getType();
807 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
808 auto loc = memRefCastOp.getLoc();
809
810 // For ranked/ranked case, just keep the original descriptor.
811 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
812 return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
813
814 if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
815 // Casting ranked to unranked memref type
816 // Set the rank in the destination from the memref type
817 // Allocate space on the stack and copy the src memref descriptor
818 // Set the ptr in the destination to the stack space
819 auto srcMemRefType = srcType.cast<MemRefType>();
820 int64_t rank = srcMemRefType.getRank();
821 // ptr = AllocaOp sizeof(MemRefDescriptor)
822 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
823 loc, adaptor.getSource(), rewriter);
824 // voidptr = BitCastOp srcType* to void*
825 auto voidPtr =
826 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
827 .getResult();
828 // rank = ConstantOp srcRank
829 auto rankVal = rewriter.create<LLVM::ConstantOp>(
830 loc, getIndexType(), rewriter.getIndexAttr(rank));
831 // undef = UndefOp
832 UnrankedMemRefDescriptor memRefDesc =
833 UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
834 // d1 = InsertValueOp undef, rank, 0
835 memRefDesc.setRank(rewriter, loc, rankVal);
836 // d2 = InsertValueOp d1, voidptr, 1
837 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
838 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
839
840 } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
841 // Casting from unranked type to ranked.
842 // The operation is assumed to be doing a correct cast. If the destination
843 // type mismatches the actual underlying type, the behavior is undefined.
844 UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
845 // ptr = ExtractValueOp src, 1
846 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
847 // castPtr = BitCastOp i8* to structTy*
848 auto castPtr =
849 rewriter
850 .create<LLVM::BitcastOp>(
851 loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
852 .getResult();
853 // struct = LoadOp castPtr
854 auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
855 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
856 } else {
857 llvm_unreachable("Unsupported unranked memref to unranked memref cast");
858 }
859 }
860 };
861
862 /// Pattern to lower a `memref.copy` to llvm.
863 ///
864 /// For memrefs with identity layouts, the copy is lowered to the llvm
865 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
866 /// to the generic `MemrefCopyFn`.
867 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
868 using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
869
870 LogicalResult
871 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
872 ConversionPatternRewriter &rewriter) const {
873 auto loc = op.getLoc();
874 auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
875
876 MemRefDescriptor srcDesc(adaptor.getSource());
877
878 // Compute number of elements.
879 Value numElements = rewriter.create<LLVM::ConstantOp>(
880 loc, getIndexType(), rewriter.getIndexAttr(1));
881 for (int pos = 0; pos < srcType.getRank(); ++pos) {
882 auto size = srcDesc.size(rewriter, loc, pos);
883 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
884 }
885
886 // Get element size.
887 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
888 // Compute total.
889 Value totalSize =
890 rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
891
892 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
893 Value srcOffset = srcDesc.offset(rewriter, loc);
894 Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
895 srcBasePtr, srcOffset);
896 MemRefDescriptor targetDesc(adaptor.getTarget());
897 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
898 Value targetOffset = targetDesc.offset(rewriter, loc);
899 Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
900 targetBasePtr, targetOffset);
901 Value isVolatile = rewriter.create<LLVM::ConstantOp>(
902 loc, typeConverter->convertType(rewriter.getI1Type()),
903 rewriter.getBoolAttr(false));
904 rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
905 isVolatile);
906 rewriter.eraseOp(op);
907
908 return success();
909 }
910
911 LogicalResult
912 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
913 ConversionPatternRewriter &rewriter) const {
914 auto loc = op.getLoc();
915 auto srcType = op.getSource().getType().cast<BaseMemRefType>();
916 auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
917
918 // First make sure we have an unranked memref descriptor representation.
919 auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
920 auto rank = rewriter.create<LLVM::ConstantOp>(
921 loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
922 auto *typeConverter = getTypeConverter();
923 auto ptr =
924 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
925 auto voidPtr =
926 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
927 .getResult();
928 auto unrankedType =
929 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
930 return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
931 unrankedType,
932 ValueRange{rank, voidPtr});
933 };
934
935 Value unrankedSource = srcType.hasRank()
936 ? makeUnranked(adaptor.getSource(), srcType)
937 : adaptor.getSource();
938 Value unrankedTarget = targetType.hasRank()
939 ? makeUnranked(adaptor.getTarget(), targetType)
940 : adaptor.getTarget();
941
942 // Now promote the unranked descriptors to the stack.
943 auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
944 rewriter.getIndexAttr(1));
945 auto promote = [&](Value desc) {
946 auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
947 auto allocated =
948 rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
949 rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
950 return allocated;
951 };
952
953 auto sourcePtr = promote(unrankedSource);
954 auto targetPtr = promote(unrankedTarget);
955
956 unsigned typeSize =
957 mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
958 auto elemSize = rewriter.create<LLVM::ConstantOp>(
959 loc, getIndexType(), rewriter.getIndexAttr(typeSize));
960 auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
961 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
962 rewriter.create<LLVM::CallOp>(loc, copyFn,
963 ValueRange{elemSize, sourcePtr, targetPtr});
964 rewriter.eraseOp(op);
965
966 return success();
967 }
968
969 LogicalResult
970 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
971 ConversionPatternRewriter &rewriter) const override {
972 auto srcType = op.getSource().getType().cast<BaseMemRefType>();
973 auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
974
975 auto isContiguousMemrefType = [](BaseMemRefType type) {
976 auto memrefType = type.dyn_cast<mlir::MemRefType>();
977 // We can use memcpy for memrefs if they have an identity layout or are
978 // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
979 // special case handled by memrefCopy.
980 return memrefType &&
981 (memrefType.getLayout().isIdentity() ||
982 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
983 isStaticShapeAndContiguousRowMajor(memrefType)));
984 };
985
986 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
987 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
988
989 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
990 }
991 };
992
993 /// Extracts the allocated and aligned pointers and the offset from a ranked
994 /// or unranked memref value. In the unranked case, the fields are extracted
995 /// from the underlying ranked descriptor.
996 static void extractPointersAndOffset(Location loc,
997 ConversionPatternRewriter &rewriter,
998 LLVMTypeConverter &typeConverter,
999 Value originalOperand,
1000 Value convertedOperand,
1001 Value *allocatedPtr, Value *alignedPtr,
1002 Value *offset = nullptr) {
1003 Type operandType = originalOperand.getType();
1004 if (operandType.isa<MemRefType>()) {
1005 MemRefDescriptor desc(convertedOperand);
1006 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1007 *alignedPtr = desc.alignedPtr(rewriter, loc);
1008 if (offset != nullptr)
1009 *offset = desc.offset(rewriter, loc);
1010 return;
1011 }
1012
1013 unsigned memorySpace =
1014 operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
1015 Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
1016 Type llvmElementType = typeConverter.convertType(elementType);
1017 Type elementPtrPtrType = LLVM::LLVMPointerType::get(
1018 LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
1019
1020 // Extract pointer to the underlying ranked memref descriptor and cast it to
1021 // ElemType**.
1022 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1023 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1024
1025 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1026 rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1027 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1028 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1029 if (offset != nullptr) {
1030 *offset = UnrankedMemRefDescriptor::offset(
1031 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1032 }
1033 }
1034
1035 struct MemRefReinterpretCastOpLowering
1036 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1037 using ConvertOpToLLVMPattern<
1038 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1039
1040 LogicalResult
1041 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1042 ConversionPatternRewriter &rewriter) const override {
1043 Type srcType = castOp.getSource().getType();
1044
1045 Value descriptor;
1046 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1047 adaptor, &descriptor)))
1048 return failure();
1049 rewriter.replaceOp(castOp, {descriptor});
1050 return success();
1051 }
1052
1053 private:
1054 LogicalResult convertSourceMemRefToDescriptor(
1055 ConversionPatternRewriter &rewriter, Type srcType,
1056 memref::ReinterpretCastOp castOp,
1057 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1058 MemRefType targetMemRefType =
1059 castOp.getResult().getType().cast<MemRefType>();
1060 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1061 .dyn_cast_or_null<LLVM::LLVMStructType>();
1062 if (!llvmTargetDescriptorTy)
1063 return failure();
1064
1065 // Create descriptor.
1066 Location loc = castOp.getLoc();
1067 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1068
1069 // Set allocated and aligned pointers.
1070 Value allocatedPtr, alignedPtr;
1071 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1072 castOp.getSource(), adaptor.getSource(),
1073 &allocatedPtr, &alignedPtr);
1074 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1075 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1076
1077 // Set offset.
1078 if (castOp.isDynamicOffset(0))
1079 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1080 else
1081 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1082
1083 // Set sizes and strides.
1084 unsigned dynSizeId = 0;
1085 unsigned dynStrideId = 0;
1086 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1087 if (castOp.isDynamicSize(i))
1088 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1089 else
1090 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1091
1092 if (castOp.isDynamicStride(i))
1093 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1094 else
1095 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1096 }
1097 *descriptor = desc;
1098 return success();
1099 }
1100 };
1101
1102 struct MemRefReshapeOpLowering
1103 : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1104 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1105
1106 LogicalResult
1107 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1108 ConversionPatternRewriter &rewriter) const override {
1109 Type srcType = reshapeOp.getSource().getType();
1110
1111 Value descriptor;
1112 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1113 adaptor, &descriptor)))
1114 return failure();
1115 rewriter.replaceOp(reshapeOp, {descriptor});
1116 return success();
1117 }
1118
1119 private:
1120 LogicalResult
1121 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1122 Type srcType, memref::ReshapeOp reshapeOp,
1123 memref::ReshapeOp::Adaptor adaptor,
1124 Value *descriptor) const {
1125 auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
1126 if (shapeMemRefType.hasStaticShape()) {
1127 MemRefType targetMemRefType =
1128 reshapeOp.getResult().getType().cast<MemRefType>();
1129 auto llvmTargetDescriptorTy =
1130 typeConverter->convertType(targetMemRefType)
1131 .dyn_cast_or_null<LLVM::LLVMStructType>();
1132 if (!llvmTargetDescriptorTy)
1133 return failure();
1134
1135 // Create descriptor.
1136 Location loc = reshapeOp.getLoc();
1137 auto desc =
1138 MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1139
1140 // Set allocated and aligned pointers.
1141 Value allocatedPtr, alignedPtr;
1142 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1143 reshapeOp.getSource(), adaptor.getSource(),
1144 &allocatedPtr, &alignedPtr);
1145 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1146 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1147
1148 // Extract the offset and strides from the type.
1149 int64_t offset;
1150 SmallVector<int64_t> strides;
1151 if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1152 return rewriter.notifyMatchFailure(
1153 reshapeOp, "failed to get stride and offset exprs");
1154
1155 if (!isStaticStrideOrOffset(offset))
1156 return rewriter.notifyMatchFailure(reshapeOp,
1157 "dynamic offset is unsupported");
1158
1159 desc.setConstantOffset(rewriter, loc, offset);
1160
1161 assert(targetMemRefType.getLayout().isIdentity() &&
1162 "Identity layout map is a precondition of a valid reshape op");
1163
1164 Value stride = nullptr;
1165 int64_t targetRank = targetMemRefType.getRank();
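// Walk the dimensions from innermost to outermost, materializing each size
// and stride. For example, for a target type memref<?x4x2xf32> this produces
// the static strides [8, 2, 1] and loads only the leading dynamic size from
// the shape operand.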
1166 for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1167 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1168 // If the stride for this dimension is static, use the constant value
1169 // directly.
1170 stride = createIndexConstant(rewriter, loc, strides[i]);
1171 } else if (!stride) {
1172 // `stride` is null only in the first iteration of the loop. However,
1173 // since the target memref has an identity layout, we can safely set
1174 // the innermost stride to 1.
1175 stride = createIndexConstant(rewriter, loc, 1);
1176 }
1177
1178 Value dimSize;
1179 int64_t size = targetMemRefType.getDimSize(i);
1180 // If the size of this dimension is dynamic, then load it at runtime
1181 // from the shape operand.
1182 if (!ShapedType::isDynamic(size)) {
1183 dimSize = createIndexConstant(rewriter, loc, size);
1184 } else {
1185 Value shapeOp = reshapeOp.getShape();
1186 Value index = createIndexConstant(rewriter, loc, i);
1187 dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1188 Type indexType = getIndexType();
1189 if (dimSize.getType() != indexType)
1190 dimSize = typeConverter->materializeTargetConversion(
1191 rewriter, loc, indexType, dimSize);
1192 assert(dimSize && "Invalid memref element type");
1193 }
1194
1195 desc.setSize(rewriter, loc, i, dimSize);
1196 desc.setStride(rewriter, loc, i, stride);
1197
1198 // Prepare the stride value for the next dimension.
1199 stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1200 }
1201
1202 *descriptor = desc;
1203 return success();
1204 }
1205
1206 // The shape is a rank-1 memref whose length is not known statically.
1207 Location loc = reshapeOp.getLoc();
1208 MemRefDescriptor shapeDesc(adaptor.getShape());
1209 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1210
1211 // Extract address space and element type.
1212 auto targetType =
1213 reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1214 unsigned addressSpace = targetType.getMemorySpaceAsInt();
1215 Type elementType = targetType.getElementType();
1216
1217 // Create the unranked memref descriptor that holds the ranked one. The
1218 // inner descriptor is allocated on stack.
1219 auto targetDesc = UnrankedMemRefDescriptor::undef(
1220 rewriter, loc, typeConverter->convertType(targetType));
1221 targetDesc.setRank(rewriter, loc, resultRank);
1222 SmallVector<Value, 4> sizes;
1223 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1224 targetDesc, sizes);
1225 Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1226 loc, getVoidPtrType(), sizes.front(), llvm::None);
1227 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1228
    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc,
                                ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None,
                                    remainder, llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
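///
/// For example (illustrative only), with a reassociation of [[0, 1], [2]]
/// expanded dims 0 and 1 both map to collapsed dim 0 and expanded dim 2 maps
/// to collapsed dim 1, i.e. the result is {0: 0, 1: 0, 2: 1}.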
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

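/// Compute the size of the `outDimIndex`-th dimension of the expanded type.
/// Static sizes are returned as index attributes; a dynamic size is recovered
/// from the corresponding collapsed dimension by dividing out the other
/// (necessarily static) sizes in its reassociation group. As a rough example,
/// expanding a dynamic dim of size %n into `?x4` yields `%n / 4` for the
/// dynamic output dim.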
static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes except the current one.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

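/// Compute the size of the `outDimIndex`-th dimension of the collapsed type.
/// Static sizes are returned as index attributes; a dynamic size is the
/// product of the sizes in its reassociation group, e.g. collapsing dims of
/// sizes (%n, 4) gives a dynamic size of `%n * 4` (a rough sketch of what the
/// code below emits).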
static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape,
                      MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

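/// Fill in the strides of the expanded (higher-rank) descriptor from the
/// strides of the source descriptor: within each reassociation group the
/// innermost expanded dim keeps the source stride and each outer dim gets the
/// running product of the inner sizes. As a rough example, expanding a dim
/// with stride %s into sizes `2x3` yields strides (3 * %s, %s).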
static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
  for (auto &en : llvm::enumerate(reassociation)) {
    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
    for (auto dstIndex : llvm::reverse(en.value())) {
      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
      Value size = dstDesc.size(b, loc, dstIndex);
      currentStrideToExpand =
          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
    }
  }
}

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
  auto srcShape = srcType.getShape();
  for (auto &en : llvm::enumerate(reassociation)) {
    rewriter.setInsertionPoint(op);
    auto dstIndex = en.index();
    ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      dstDesc.setStride(rewriter, loc, dstIndex,
                        srcDesc.stride(rewriter, loc, ref.back()));
    } else {
      // Iterate over the source strides in reverse order. Skip over the
      // dimensions whose size is 1.
      // TODO: we should take the minimum stride in the reassociation group
      // instead of just the first where the dimension is not 1.
      //
      // +------------------------------------------------------+
      // | curEntry:                                             |
      // |   %srcStride = strides[srcIndex]                      |
      // |   %neOne = cmp sizes[srcIndex],1                      +--+
      // |   cf.cond_br %neOne, continue(%srcStride), nextEntry  |  |
      // +-------------------------+----------------------------+  |
      //                           |                               |
      //                           v                               |
      //           +-----------------------------+                 |
      //           | nextEntry:                  |                 |
      //           |   ...                       +---+             |
      //           +--------------+--------------+   |             |
      //                          |                  |             |
      //                          v                  |             |
      //           +-----------------------------+   |             |
      //           | nextEntry:                  |   |             |
      //           |   ...                       |   |   +---------+
      //           +--------------+--------------+   |   |
      //                          |                  |   |
      //                          v                  v   v
      //          +--------------------------------------------------+
      //          | continue(%newStride):                            |
      //          |   %newMemRefDes = setStride(%newStride,dstIndex) |
      //          +--------------------------------------------------+
      OpBuilder::InsertionGuard guard(rewriter);
      Block *initBlock = rewriter.getInsertionBlock();
      Block *continueBlock =
          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
      rewriter.setInsertionPointToStart(continueBlock);
      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));

      Block *curEntryBlock = initBlock;
      Block *nextEntryBlock;
      for (auto srcIndex : llvm::reverse(ref)) {
        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
          continue;
        rewriter.setInsertionPointToEnd(curEntryBlock);
        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
        if (srcIndex == ref.front()) {
          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
          break;
        }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI64IntegerAttr(1));
        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
            loc, LLVM::ICmpPredicate::ne,
            srcDesc.size(rewriter, loc, srcIndex), one);
        {
          OpBuilder::InsertionGuard guard(rewriter);
          nextEntryBlock = rewriter.createBlock(
              initBlock->getParent(), Region::iterator(continueBlock), {});
        }
        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
                                        srcStride, nextEntryBlock, llvm::None);
        curEntryBlock = nextEntryBlock;
      }
    }
  }
}

static void fillInDynamicStridesForMemDescriptor(
    ConversionPatternRewriter &b, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getRank() > dstType.getRank())
    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
                                           srcDesc, dstDesc, reassociation);
  else
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
}

// ReshapeOp creates a new view descriptor of the proper rank.
// Sizes are taken from the static result shape or recomputed from the source
// descriptor; strides are materialized as constants when fully static and
// recomputed otherwise.
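//
// A rough sketch of the effect (illustrative only; the exact IR depends on
// the type converter):
//
//   %e = memref.expand_shape %src [[0, 1]] : memref<8xf32> into memref<2x4xf32>
//
// becomes a chain of llvm.insertvalue ops that copy the allocated/aligned
// pointers and the offset from the source descriptor and fill in the new
// sizes (2, 4) and strides (4, 1).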
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.getSrc());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else if (srcType.getLayout().isIdentity() &&
               dstType.getLayout().isIdentity()) {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    } else {
      // There could be mixed static/dynamic strides. For simplicity, we
      // recompute all strides if there is at least one dynamic stride.
      fillInDynamicStridesForMemDescriptor(
          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
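///
/// As a rough example, for `%0 = memref.subview %src[%o][4][%s]` on a 1-D
/// source, the new descriptor reuses the source pointers and sets
/// offset = srcOffset + %o * srcStride0, size0 = 4, and
/// stride0 = %s * srcStride0.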
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
            extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
            extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may run past the end of subViewOp.getMixedSizes() because of
      // trailing semantics. In this case, the size is guaranteed to be
      // interpreted as the source dimension and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
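///
/// A rough sketch: for `memref.transpose %src (i, j) -> (j, i)` on a 2-D
/// memref, the result descriptor keeps the source pointers and offset while
/// size/stride 0 and 1 are swapped.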
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset pointer from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms an op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
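///
/// A rough sketch: `memref.view %src[%shift][...]` on a `memref<?xi8>`
/// produces a descriptor whose aligned pointer is the source aligned pointer
/// advanced by %shift bytes and bitcast to the target element type, with a
/// zero offset and sizes/strides rebuilt for the target shape.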
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx]
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
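/// For instance (illustrative), `memref.atomic_rmw addf` maps to
/// `llvm.atomicrmw fadd`; kinds without a direct LLVM counterpart return
/// llvm::None and this pattern simply does not apply.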
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.getValue().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices(),
                             rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

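// Typical driver usage (illustrative): `mlir-opt --convert-memref-to-llvm`
// runs the pass below, which wires the patterns above into a partial
// conversion towards the LLVM dialect.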
namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}