1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
30   return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
31 }
32 
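/// Lowers `memref.alloc` to a `malloc`-based allocation. Roughly, and only as
/// an illustration (the exact IR depends on the index bitwidth, data layout,
/// and whether an alignment attribute is present):
///
///   %mem = memref.alloc() : memref<32xf32>
///
/// becomes
///
///   %size = llvm.mlir.constant(128 : i64) : i64
///   %raw  = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
///   %ptr  = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
///   ... %ptr (and, when an alignment is requested, a pointer bumped up to
///   that alignment) is then packed into the memref descriptor ...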
33 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
34   AllocOpLowering(LLVMTypeConverter &converter)
35       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
36                                 converter) {}
37 
38   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
39                                           Location loc, Value sizeBytes,
40                                           Operation *op) const override {
41     // Heap allocations.
42     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
43     MemRefType memRefType = allocOp.getType();
44 
45     Value alignment;
46     if (auto alignmentAttr = allocOp.alignment()) {
47       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
48     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
49       // In the case where no alignment is specified, we may want to override
50       // `malloc's` behavior. `malloc` typically aligns at the size of the
51       // biggest scalar on a target HW. For non-scalars, use the natural
52       // alignment of the LLVM type given by the LLVM DataLayout.
53       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
54     }
55 
56     if (alignment) {
57       // Adjust the allocation size to consider alignment.
58       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
59     }
60 
61     // Allocate the underlying buffer and store a pointer to it in the MemRef
62     // descriptor.
63     Type elementPtrType = this->getElementPtrType(memRefType);
64     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
65         allocOp->getParentOfType<ModuleOp>(), getIndexType());
66     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
67                                   getVoidPtrType());
68     Value allocatedPtr =
69         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
70 
71     Value alignedPtr = allocatedPtr;
72     if (alignment) {
73       // Compute the aligned type pointer.
74       Value allocatedInt =
75           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
76       Value alignmentInt =
77           createAligned(rewriter, loc, allocatedInt, alignment);
78       alignedPtr =
79           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
80     }
81 
82     return std::make_tuple(allocatedPtr, alignedPtr);
83   }
84 };
85 
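/// Lowers `memref.alloc` through `aligned_alloc` instead of `malloc` when this
/// lowering is selected. A minimal sketch (illustrative only):
///
///   %align = llvm.mlir.constant(16 : i64) : i64
///   %size  = ... allocation size, padded to a multiple of %align if needed ...
///   %raw   = llvm.call @aligned_alloc(%align, %size)
///              : (i64, i64) -> !llvm.ptr<i8>
///
/// The returned pointer is already aligned, so it serves as both the allocated
/// and the aligned pointer of the descriptor.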
86 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
87   AlignedAllocOpLowering(LLVMTypeConverter &converter)
88       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
89                                 converter) {}
90 
91   /// Returns the memref's element size in bytes using the data layout active at
92   /// `op`.
93   // TODO: there are other places where this is used. Expose publicly?
94   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
95     const DataLayout *layout = &defaultLayout;
96     if (const DataLayoutAnalysis *analysis =
97             getTypeConverter()->getDataLayoutAnalysis()) {
98       layout = &analysis->getAbove(op);
99     }
100     Type elementType = memRefType.getElementType();
101     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
102       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
103                                                          *layout);
104     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
105       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
106           memRefElementType, *layout);
107     return layout->getTypeSize(elementType);
108   }
109 
110   /// Returns true if the memref size in bytes is known to be a multiple of
111   /// factor assuming the data layout active at `op`.
112   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
113                               Operation *op) const {
114     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
115     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
116       if (ShapedType::isDynamic(type.getDimSize(i)))
117         continue;
118       sizeDivisor = sizeDivisor * type.getDimSize(i);
119     }
120     return sizeDivisor % factor == 0;
121   }
122 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
126   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
127     if (Optional<uint64_t> alignment = allocOp.alignment())
128       return *alignment;
129 
    // Whenever no alignment is specified, use an alignment consistent with the
    // element type; since the alignment has to be a power of two, bump it up
    // to the next power of two if it isn't one already.
133     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
134     return std::max(kMinAlignedAllocAlignment,
135                     llvm::PowerOf2Ceil(eltSizeBytes));
136   }
137 
138   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
139                                           Location loc, Value sizeBytes,
140                                           Operation *op) const override {
141     // Heap allocations.
142     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
143     MemRefType memRefType = allocOp.getType();
144     int64_t alignment = getAllocationAlignment(allocOp);
145     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
146 
147     // aligned_alloc requires size to be a multiple of alignment; we will pad
148     // the size to the next multiple if necessary.
149     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
150       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
151 
152     Type elementPtrType = this->getElementPtrType(memRefType);
153     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
154         allocOp->getParentOfType<ModuleOp>(), getIndexType());
155     auto results =
156         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
157                        getVoidPtrType());
158     Value allocatedPtr =
159         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
160 
161     return std::make_tuple(allocatedPtr, allocatedPtr);
162   }
163 
164   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
165   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
166 
167   /// Default layout to use in absence of the corresponding analysis.
168   DataLayout defaultLayout;
169 };
170 
// Out-of-line definition, required until C++17.
172 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
173 
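/// Lowers `memref.alloca` to a single `llvm.alloca` of the underlying buffer;
/// the optional alignment attribute is forwarded to the alloca (0 meaning the
/// target default). Since the stack allocation already honors the requested
/// alignment, the allocated and aligned pointers coincide.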
174 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
175   AllocaOpLowering(LLVMTypeConverter &converter)
176       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
177                                 converter) {}
178 
  /// Allocates the underlying buffer with an llvm.alloca. Since the stack
  /// allocation already honors the requested alignment, the allocated and
  /// aligned pointers returned are the same.
182   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
183                                           Location loc, Value sizeBytes,
184                                           Operation *op) const override {
185 
    // With alloca, we directly get a pointer to the element type: this is a
    // stack allocation, so no post-allocation alignment step is needed.
188     auto allocaOp = cast<memref::AllocaOp>(op);
189     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
190 
191     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
192         loc, elementPtrType, sizeBytes,
193         allocaOp.alignment() ? *allocaOp.alignment() : 0);
194 
195     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
196   }
197 };
198 
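/// Lowers `memref.alloca_scope` by inlining its region into the surrounding
/// CFG and bracketing it with stack save/restore intrinsics. A rough sketch of
/// the resulting structure (illustrative):
///
///   ^before:
///     %sp = llvm.intr.stacksave : !llvm.ptr<i8>
///     llvm.br ^body
///   ^body:                          // inlined alloca_scope region
///     ...
///     llvm.intr.stackrestore %sp
///     llvm.br ^continue(...)
///   ^continue(...):                 // code after the alloca_scope
///     ...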
199 struct AllocaScopeOpLowering
200     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
201   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
202 
203   LogicalResult
204   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
205                   ConversionPatternRewriter &rewriter) const override {
206     OpBuilder::InsertionGuard guard(rewriter);
207     Location loc = allocaScopeOp.getLoc();
208 
209     // Split the current block before the AllocaScopeOp to create the inlining
210     // point.
211     auto *currentBlock = rewriter.getInsertionBlock();
212     auto *remainingOpsBlock =
213         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
214     Block *continueBlock;
215     if (allocaScopeOp.getNumResults() == 0) {
216       continueBlock = remainingOpsBlock;
217     } else {
218       continueBlock = rewriter.createBlock(
219           remainingOpsBlock, allocaScopeOp.getResultTypes(),
220           SmallVector<Location>(allocaScopeOp->getNumResults(),
221                                 allocaScopeOp.getLoc()));
222       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
223     }
224 
225     // Inline body region.
226     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
227     Block *afterBody = &allocaScopeOp.bodyRegion().back();
228     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
229 
230     // Save stack and then branch into the body of the region.
231     rewriter.setInsertionPointToEnd(currentBlock);
232     auto stackSaveOp =
233         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
234     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
235 
    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored right before leaving the body region (see
    // below).
238     rewriter.setInsertionPointToEnd(afterBody);
239     auto returnOp =
240         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
241     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
242         returnOp, returnOp.results(), continueBlock);
243 
244     // Insert stack restore before jumping out the body of the region.
245     rewriter.setInsertionPoint(branchOp);
246     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
247 
    // Replace the op with the values returned from the body region.
249     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
250 
251     return success();
252   }
253 };
254 
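/// Lowers `memref.assume_alignment` to an `llvm.intr.assume` on the aligned
/// pointer of the descriptor. For an alignment of 16, the emitted check is
/// roughly (illustrative):
///
///   %p   = ... aligned pointer extracted from the descriptor ...
///   %pi  = llvm.ptrtoint %p : !llvm.ptr<f32> to i64
///   %and = llvm.and %pi, %mask            // %mask = 16 - 1
///   %cmp = llvm.icmp "eq" %and, %zero
///   "llvm.intr.assume"(%cmp) : (i1) -> ()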
255 struct AssumeAlignmentOpLowering
256     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
257   using ConvertOpToLLVMPattern<
258       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
259 
260   LogicalResult
261   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
262                   ConversionPatternRewriter &rewriter) const override {
263     Value memref = adaptor.memref();
264     unsigned alignment = op.alignment();
265     auto loc = op.getLoc();
266 
267     MemRefDescriptor memRefDescriptor(memref);
268     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
269 
270     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
271     // the asserted memref.alignedPtr isn't used anywhere else, as the real
272     // users like load/store/views always re-extract memref.alignedPtr as they
273     // get lowered.
274     //
275     // This relies on LLVM's CSE optimization (potentially after SROA), since
276     // after CSE all memref.alignedPtr instances get de-duplicated into the same
277     // pointer SSA value.
278     auto intPtrType =
279         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
280     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
281     Value mask =
282         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
283     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
284     rewriter.create<LLVM::AssumeOp>(
285         loc, rewriter.create<LLVM::ICmpOp>(
286                  loc, LLVM::ICmpPredicate::eq,
287                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
288 
289     rewriter.eraseOp(op);
290     return success();
291   }
292 };
293 
294 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
295 // The memref descriptor being an SSA value, there is no need to clean it up
296 // in any way.
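//
// For example, `memref.dealloc %m : memref<?xf32>` becomes, roughly
// (illustrative):
//
//   %ptr = ... allocated pointer extracted from the descriptor ...
//   %p8  = llvm.bitcast %ptr : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%p8) : (!llvm.ptr<i8>) -> ()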
297 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
298   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
299 
300   explicit DeallocOpLowering(LLVMTypeConverter &converter)
301       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
302 
303   LogicalResult
304   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
305                   ConversionPatternRewriter &rewriter) const override {
306     // Insert the `free` declaration if it is not already present.
307     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
308     MemRefDescriptor memref(adaptor.memref());
309     Value casted = rewriter.create<LLVM::BitcastOp>(
310         op.getLoc(), getVoidPtrType(),
311         memref.allocatedPtr(rewriter, op.getLoc()));
312     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
313         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
314     return success();
315   }
316 };
317 
318 // A `dim` is converted to a constant for static sizes and to an access to the
319 // size stored in the memref descriptor for dynamic sizes.
320 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
321   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
322 
323   LogicalResult
324   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
325                   ConversionPatternRewriter &rewriter) const override {
326     Type operandType = dimOp.source().getType();
327     if (operandType.isa<UnrankedMemRefType>()) {
328       rewriter.replaceOp(
329           dimOp, {extractSizeOfUnrankedMemRef(
330                      operandType, dimOp, adaptor.getOperands(), rewriter)});
331 
332       return success();
333     }
334     if (operandType.isa<MemRefType>()) {
335       rewriter.replaceOp(
336           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
337                                             adaptor.getOperands(), rewriter)});
338       return success();
339     }
340     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
341   }
342 
343 private:
344   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
345                                     OpAdaptor adaptor,
346                                     ConversionPatternRewriter &rewriter) const {
347     Location loc = dimOp.getLoc();
348 
349     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
350     auto scalarMemRefType =
351         MemRefType::get({}, unrankedMemRefType.getElementType());
352     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
353 
354     // Extract pointer to the underlying ranked descriptor and bitcast it to a
355     // memref<element_type> descriptor pointer to minimize the number of GEP
356     // operations.
357     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
358     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
359     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
360         loc,
361         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
362                                    addressSpace),
363         underlyingRankedDesc);
364 
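    // The converted ranked descriptor is a struct of the form
    // {allocated ptr, aligned ptr, offset, sizes[rank], strides[rank]}, so
    // field 2 is the offset and the sizes array starts immediately after it in
    // memory. This is what makes the `index + 1` GEP below yield the requested
    // size.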
365     // Get pointer to offset field of memref<element_type> descriptor.
366     Type indexPtrTy = LLVM::LLVMPointerType::get(
367         getTypeConverter()->getIndexType(), addressSpace);
368     Value two = rewriter.create<LLVM::ConstantOp>(
369         loc, typeConverter->convertType(rewriter.getI32Type()),
370         rewriter.getI32IntegerAttr(2));
371     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
372         loc, indexPtrTy, scalarMemRefDescPtr,
373         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
374 
    // The requested size can then be obtained with a GEPOp using
    // `dimOp.index() + 1` as the index argument.
377     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
378         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
379     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
380                                                  ValueRange({idxPlusOne}));
381     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
382   }
383 
384   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
385     if (Optional<int64_t> idx = dimOp.getConstantIndex())
386       return idx;
387 
388     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
389       return constantOp.getValue()
390           .cast<IntegerAttr>()
391           .getValue()
392           .getSExtValue();
393 
394     return llvm::None;
395   }
396 
397   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
398                                   OpAdaptor adaptor,
399                                   ConversionPatternRewriter &rewriter) const {
400     Location loc = dimOp.getLoc();
401 
    // Take advantage of a constant index when available.
403     MemRefType memRefType = operandType.cast<MemRefType>();
404     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
405       int64_t i = index.getValue();
406       if (memRefType.isDynamicDim(i)) {
407         // extract dynamic size from the memref descriptor.
408         MemRefDescriptor descriptor(adaptor.source());
409         return descriptor.size(rewriter, loc, i);
410       }
411       // Use constant for static size.
412       int64_t dimSize = memRefType.getDimSize(i);
413       return createIndexConstant(rewriter, loc, dimSize);
414     }
415     Value index = adaptor.index();
416     int64_t rank = memRefType.getRank();
417     MemRefDescriptor memrefDescriptor(adaptor.source());
418     return memrefDescriptor.size(rewriter, loc, index, rank);
419   }
420 };
421 
422 /// Common base for load and store operations on MemRefs. Restricts the match
423 /// to supported MemRef types. Provides functionality to emit code accessing a
424 /// specific element of the underlying data buffer.
425 template <typename Derived>
426 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
427   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
428   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
429   using Base = LoadStoreOpLowering<Derived>;
430 
431   LogicalResult match(Derived op) const override {
432     MemRefType type = op.getMemRefType();
433     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
434   }
435 };
436 
/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
438 /// retried until it succeeds in atomically storing a new value into memory.
439 ///
440 ///      +---------------------------------+
441 ///      |   <code before the AtomicRMWOp> |
442 ///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
444 ///      +---------------------------------+
445 ///             |
446 ///  -------|   |
447 ///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
456 ///  |          |        |
457 ///  |-----------        |
458 ///                      v
459 ///      +--------------------------------+
460 ///      | end:                           |
461 ///      |   <code after the AtomicRMWOp> |
462 ///      +--------------------------------+
463 ///
464 struct GenericAtomicRMWOpLowering
465     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
466   using Base::Base;
467 
468   LogicalResult
469   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
470                   ConversionPatternRewriter &rewriter) const override {
471     auto loc = atomicOp.getLoc();
472     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
473 
474     // Split the block into initial, loop, and ending parts.
475     auto *initBlock = rewriter.getInsertionBlock();
476     auto *loopBlock = rewriter.createBlock(
477         initBlock->getParent(), std::next(Region::iterator(initBlock)),
478         valueType, loc);
479     auto *endBlock = rewriter.createBlock(
480         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
481 
    // Range of operations to be moved to `endBlock`.
483     auto opsToMoveStart = atomicOp->getIterator();
484     auto opsToMoveEnd = initBlock->back().getIterator();
485 
486     // Compute the loaded value and branch to the loop block.
487     rewriter.setInsertionPointToEnd(initBlock);
488     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
489     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
490                                         adaptor.indices(), rewriter);
491     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
492     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
493 
494     // Prepare the body of the loop block.
495     rewriter.setInsertionPointToStart(loopBlock);
496 
497     // Clone the GenericAtomicRMWOp region and extract the result.
498     auto loopArgument = loopBlock->getArgument(0);
499     BlockAndValueMapping mapping;
500     mapping.map(atomicOp.getCurrentValue(), loopArgument);
501     Block &entryBlock = atomicOp.body().front();
502     for (auto &nestedOp : entryBlock.without_terminator()) {
503       Operation *clone = rewriter.clone(nestedOp, mapping);
504       mapping.map(nestedOp.getResults(), clone->getResults());
505     }
506     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
507 
508     // Prepare the epilog of the loop block.
509     // Append the cmpxchg op to the end of the loop block.
510     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
511     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
512     auto boolType = IntegerType::get(rewriter.getContext(), 1);
513     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
514                                                      {valueType, boolType});
515     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
516         loc, pairType, dataPtr, loopArgument, result, successOrdering,
517         failureOrdering);
518     // Extract the %new_loaded and %ok values from the pair.
519     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
520         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
521     Value ok = rewriter.create<LLVM::ExtractValueOp>(
522         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
523 
524     // Conditionally branch to the end or back to the loop depending on %ok.
525     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
526                                     loopBlock, newLoaded);
527 
528     rewriter.setInsertionPointToEnd(endBlock);
529     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
530                  std::next(opsToMoveEnd), rewriter);
531 
532     // The 'result' of the atomic_rmw op is the newly loaded value.
533     rewriter.replaceOp(atomicOp, {newLoaded});
534 
535     return success();
536   }
537 
538 private:
  // Clones the ops in [start, end) at the current insertion point (remapping
  // `oldResult` to `newResult`) and erases the originals.
540   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
541                     Block::iterator start, Block::iterator end,
542                     ConversionPatternRewriter &rewriter) const {
543     BlockAndValueMapping mapping;
544     mapping.map(oldResult, newResult);
545     SmallVector<Operation *, 2> opsToErase;
546     for (auto it = start; it != end; ++it) {
547       rewriter.clone(*it, mapping);
548       opsToErase.push_back(&*it);
549     }
550     for (auto *it : opsToErase)
551       rewriter.eraseOp(it);
552   }
553 };
554 
555 /// Returns the LLVM type of the global variable given the memref type `type`.
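/// For instance, `memref<4x8xf32>` is expected to convert to
/// `!llvm.array<4 x array<8 x f32>>` (illustrative).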
556 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
557                                           LLVMTypeConverter &typeConverter) {
  // The LLVM type of a global memref is a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global ops with an initial value,
  // we do not intend to flatten the ElementsAttr when going from the memref to
  // the LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
563   Type elementType = typeConverter.convertType(type.getElementType());
564   Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
566   for (int64_t dim : llvm::reverse(type.getShape()))
567     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
568   return arrayTy;
569 }
570 
/// GlobalMemrefOp is lowered to an LLVM global variable.
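/// As a rough illustration (exact syntax may differ):
///
///   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
///
/// maps to
///
///   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>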
572 struct GlobalMemrefOpLowering
573     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
574   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
575 
576   LogicalResult
577   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
578                   ConversionPatternRewriter &rewriter) const override {
579     MemRefType type = global.type();
580     if (!isConvertibleAndHasIdentityMaps(type))
581       return failure();
582 
583     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
584 
585     LLVM::Linkage linkage =
586         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
587 
588     Attribute initialValue = nullptr;
589     if (!global.isExternal() && !global.isUninitialized()) {
590       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
591       initialValue = elementsAttr;
592 
593       // For scalar memrefs, the global variable created is of the element type,
594       // so unpack the elements attribute to extract the value.
595       if (type.getRank() == 0)
596         initialValue = elementsAttr.getSplatValue<Attribute>();
597     }
598 
599     uint64_t alignment = global.alignment().value_or(0);
600 
601     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
602         global, arrayTy, global.constant(), linkage, global.sym_name(),
603         initialValue, alignment, type.getMemorySpaceAsInt());
604     if (!global.isExternal() && global.isUninitialized()) {
605       Block *blk = new Block();
606       newGlobal.getInitializerRegion().push_back(blk);
607       rewriter.setInsertionPointToStart(blk);
608       Value undef[] = {
609           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
610       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
611     }
612     return success();
613   }
614 };
615 
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element of the global stashed into the descriptor. It derives
/// from `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
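///
/// Roughly (illustrative):
///
///   %0 = memref.get_global @gv : memref<2x3xf32>
///
/// becomes
///
///   %addr = llvm.mlir.addressof @gv : !llvm.ptr<array<2 x array<3 x f32>>>
///   %gep  = llvm.getelementptr %addr[0, 0, 0] ... -> !llvm.ptr<f32>
///   ... %gep becomes the aligned pointer of the resulting descriptor ...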
619 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
620   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
621       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
622                                 converter) {}
623 
624   /// Buffer "allocation" for memref.get_global op is getting the address of
625   /// the global variable referenced.
626   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
627                                           Location loc, Value sizeBytes,
628                                           Operation *op) const override {
629     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
630     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
631     unsigned memSpace = type.getMemorySpaceAsInt();
632 
633     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
634     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
635         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
636 
637     // Get the address of the first element in the array by creating a GEP with
638     // the address of the GV as the base, and (rank + 1) number of 0 indices.
639     Type elementType = typeConverter->convertType(type.getElementType());
640     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
641 
642     SmallVector<Value> operands;
643     operands.insert(operands.end(), type.getRank() + 1,
644                     createIndexConstant(rewriter, loc, 0));
645     auto gep =
646         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
647 
    // We do not expect a memref obtained with `memref.get_global` to ever be
    // deallocated. Set the allocated pointer to a known bad value to help
    // debugging if that ever happens.
651     auto intPtrType = getIntPtrType(memSpace);
652     Value deadBeefConst =
653         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
654     auto deadBeefPtr =
655         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
656 
    // The aligned pointer is the GEP into the global; the allocated pointer is
    // deliberately a bad value since we do not expect any dealloc. We could
    // alternatively stash a nullptr there.
659     return std::make_tuple(deadBeefPtr, gep);
660   }
661 };
662 
663 // Load operation is lowered to obtaining a pointer to the indexed element
664 // and loading it.
665 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
666   using Base::Base;
667 
668   LogicalResult
669   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
670                   ConversionPatternRewriter &rewriter) const override {
671     auto type = loadOp.getMemRefType();
672 
673     Value dataPtr = getStridedElementPtr(
674         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
675     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
676     return success();
677   }
678 };
679 
680 // Store operation is lowered to obtaining a pointer to the indexed element,
681 // and storing the given value to it.
682 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
683   using Base::Base;
684 
685   LogicalResult
686   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
687                   ConversionPatternRewriter &rewriter) const override {
688     auto type = op.getMemRefType();
689 
690     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
691                                          adaptor.indices(), rewriter);
692     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
693     return success();
694   }
695 };
696 
697 // The prefetch operation is lowered in a way similar to the load operation
698 // except that the llvm.prefetch operation is used for replacement.
699 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
700   using Base::Base;
701 
702   LogicalResult
703   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
704                   ConversionPatternRewriter &rewriter) const override {
705     auto type = prefetchOp.getMemRefType();
706     auto loc = prefetchOp.getLoc();
707 
708     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
709                                          adaptor.indices(), rewriter);
710 
711     // Replace with llvm.prefetch.
712     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
713     auto isWrite = rewriter.create<LLVM::ConstantOp>(
714         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
715     auto localityHint = rewriter.create<LLVM::ConstantOp>(
716         loc, llvmI32Type,
717         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
718     auto isData = rewriter.create<LLVM::ConstantOp>(
719         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
720 
721     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
722                                                 localityHint, isData);
723     return success();
724   }
725 };
726 
727 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
728   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
729 
730   LogicalResult
731   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
732                   ConversionPatternRewriter &rewriter) const override {
733     Location loc = op.getLoc();
734     Type operandType = op.memref().getType();
735     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
736       UnrankedMemRefDescriptor desc(adaptor.memref());
737       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
738       return success();
739     }
740     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
741       rewriter.replaceOp(
742           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
743       return success();
744     }
745     return failure();
746   }
747 };
748 
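/// Lowers `memref.cast`. A ranked-to-ranked cast keeps the source descriptor
/// as-is; a ranked-to-unranked cast stores the ranked descriptor on the stack
/// and wraps its address together with the rank in an unranked descriptor; an
/// unranked-to-ranked cast loads the ranked descriptor back from that pointer.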
749 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
750   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
751 
752   LogicalResult match(memref::CastOp memRefCastOp) const override {
753     Type srcType = memRefCastOp.getOperand().getType();
754     Type dstType = memRefCastOp.getType();
755 
    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed we can revisit.
761     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
762       return success(typeConverter->convertType(srcType) ==
763                      typeConverter->convertType(dstType));
764 
    // At least one of the operands is of unranked memref type.
766     assert(srcType.isa<UnrankedMemRefType>() ||
767            dstType.isa<UnrankedMemRefType>());
768 
    // Unranked-to-unranked casts are disallowed.
770     return !(srcType.isa<UnrankedMemRefType>() &&
771              dstType.isa<UnrankedMemRefType>())
772                ? success()
773                : failure();
774   }
775 
776   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
777                ConversionPatternRewriter &rewriter) const override {
778     auto srcType = memRefCastOp.getOperand().getType();
779     auto dstType = memRefCastOp.getType();
780     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
781     auto loc = memRefCastOp.getLoc();
782 
783     // For ranked/ranked case, just keep the original descriptor.
784     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
785       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
786 
787     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
788       // Casting ranked to unranked memref type
789       // Set the rank in the destination from the memref type
790       // Allocate space on the stack and copy the src memref descriptor
791       // Set the ptr in the destination to the stack space
792       auto srcMemRefType = srcType.cast<MemRefType>();
793       int64_t rank = srcMemRefType.getRank();
794       // ptr = AllocaOp sizeof(MemRefDescriptor)
795       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
796           loc, adaptor.source(), rewriter);
797       // voidptr = BitCastOp srcType* to void*
798       auto voidPtr =
799           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
800               .getResult();
801       // rank = ConstantOp srcRank
802       auto rankVal = rewriter.create<LLVM::ConstantOp>(
803           loc, getIndexType(), rewriter.getIndexAttr(rank));
804       // undef = UndefOp
805       UnrankedMemRefDescriptor memRefDesc =
806           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
807       // d1 = InsertValueOp undef, rank, 0
808       memRefDesc.setRank(rewriter, loc, rankVal);
809       // d2 = InsertValueOp d1, voidptr, 1
810       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
811       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
812 
813     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
814       // Casting from unranked type to ranked.
      // The operation is assumed to be performing a correct cast. If the
      // destination type does not match the type of the underlying ranked
      // descriptor, the behavior is undefined.
817       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
818       // ptr = ExtractValueOp src, 1
819       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
820       // castPtr = BitCastOp i8* to structTy*
821       auto castPtr =
822           rewriter
823               .create<LLVM::BitcastOp>(
824                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
825               .getResult();
826       // struct = LoadOp castPtr
827       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
828       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
829     } else {
830       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
831     }
832   }
833 };
834 
835 /// Pattern to lower a `memref.copy` to llvm.
836 ///
837 /// For memrefs with identity layouts, the copy is lowered to the llvm
838 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
839 /// to the generic `MemrefCopyFn`.
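///
/// For two contiguous, identity-layout memrefs, the intrinsic path boils down
/// to roughly (illustrative):
///
///   %bytes = llvm.mul %numElements, %elemSizeInBytes
///   "llvm.intr.memcpy"(%dstPtr, %srcPtr, %bytes, %false)
///       : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()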
840 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
841   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
842 
843   LogicalResult
844   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
845                           ConversionPatternRewriter &rewriter) const {
846     auto loc = op.getLoc();
847     auto srcType = op.source().getType().dyn_cast<MemRefType>();
848 
849     MemRefDescriptor srcDesc(adaptor.source());
850 
851     // Compute number of elements.
852     Value numElements = rewriter.create<LLVM::ConstantOp>(
853         loc, getIndexType(), rewriter.getIndexAttr(1));
854     for (int pos = 0; pos < srcType.getRank(); ++pos) {
855       auto size = srcDesc.size(rewriter, loc, pos);
856       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
857     }
858 
859     // Get element size.
860     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
861     // Compute total.
862     Value totalSize =
863         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
864 
865     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
866     Value srcOffset = srcDesc.offset(rewriter, loc);
867     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
868                                                 srcBasePtr, srcOffset);
869     MemRefDescriptor targetDesc(adaptor.target());
870     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
871     Value targetOffset = targetDesc.offset(rewriter, loc);
872     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
873                                                    targetBasePtr, targetOffset);
874     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
875         loc, typeConverter->convertType(rewriter.getI1Type()),
876         rewriter.getBoolAttr(false));
877     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
878                                     isVolatile);
879     rewriter.eraseOp(op);
880 
881     return success();
882   }
883 
884   LogicalResult
885   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
886                              ConversionPatternRewriter &rewriter) const {
887     auto loc = op.getLoc();
888     auto srcType = op.source().getType().cast<BaseMemRefType>();
889     auto targetType = op.target().getType().cast<BaseMemRefType>();
890 
891     // First make sure we have an unranked memref descriptor representation.
892     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
893       auto rank = rewriter.create<LLVM::ConstantOp>(
894           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
895       auto *typeConverter = getTypeConverter();
896       auto ptr =
897           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
898       auto voidPtr =
899           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
900               .getResult();
901       auto unrankedType =
902           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
903       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
904                                             unrankedType,
905                                             ValueRange{rank, voidPtr});
906     };
907 
908     Value unrankedSource = srcType.hasRank()
909                                ? makeUnranked(adaptor.source(), srcType)
910                                : adaptor.source();
911     Value unrankedTarget = targetType.hasRank()
912                                ? makeUnranked(adaptor.target(), targetType)
913                                : adaptor.target();
914 
915     // Now promote the unranked descriptors to the stack.
916     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
917                                                  rewriter.getIndexAttr(1));
918     auto promote = [&](Value desc) {
919       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
920       auto allocated =
921           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
922       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
923       return allocated;
924     };
925 
926     auto sourcePtr = promote(unrankedSource);
927     auto targetPtr = promote(unrankedTarget);
928 
929     unsigned typeSize =
930         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
931     auto elemSize = rewriter.create<LLVM::ConstantOp>(
932         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
933     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
934         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
935     rewriter.create<LLVM::CallOp>(loc, copyFn,
936                                   ValueRange{elemSize, sourcePtr, targetPtr});
937     rewriter.eraseOp(op);
938 
939     return success();
940   }
941 
942   LogicalResult
943   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
944                   ConversionPatternRewriter &rewriter) const override {
945     auto srcType = op.source().getType().cast<BaseMemRefType>();
946     auto targetType = op.target().getType().cast<BaseMemRefType>();
947 
948     auto isContiguousMemrefType = [](BaseMemRefType type) {
949       auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
953       return memrefType &&
954              (memrefType.getLayout().isIdentity() ||
955               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
956                isStaticShapeAndContiguousRowMajor(memrefType)));
957     };
958 
959     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
960       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
961 
962     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
963   }
964 };
965 
966 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
967 /// memref type. In unranked case, the fields are extracted from the underlying
968 /// ranked descriptor.
969 static void extractPointersAndOffset(Location loc,
970                                      ConversionPatternRewriter &rewriter,
971                                      LLVMTypeConverter &typeConverter,
972                                      Value originalOperand,
973                                      Value convertedOperand,
974                                      Value *allocatedPtr, Value *alignedPtr,
975                                      Value *offset = nullptr) {
976   Type operandType = originalOperand.getType();
977   if (operandType.isa<MemRefType>()) {
978     MemRefDescriptor desc(convertedOperand);
979     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
980     *alignedPtr = desc.alignedPtr(rewriter, loc);
981     if (offset != nullptr)
982       *offset = desc.offset(rewriter, loc);
983     return;
984   }
985 
986   unsigned memorySpace =
987       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
988   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
989   Type llvmElementType = typeConverter.convertType(elementType);
990   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
991       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
992 
993   // Extract pointer to the underlying ranked memref descriptor and cast it to
994   // ElemType**.
995   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
996   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
997 
998   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
999       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1000   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1001       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1002   if (offset != nullptr) {
1003     *offset = UnrankedMemRefDescriptor::offset(
1004         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1005   }
1006 }
1007 
1008 struct MemRefReinterpretCastOpLowering
1009     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1010   using ConvertOpToLLVMPattern<
1011       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1012 
1013   LogicalResult
1014   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1015                   ConversionPatternRewriter &rewriter) const override {
1016     Type srcType = castOp.source().getType();
1017 
1018     Value descriptor;
1019     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1020                                                adaptor, &descriptor)))
1021       return failure();
1022     rewriter.replaceOp(castOp, {descriptor});
1023     return success();
1024   }
1025 
1026 private:
1027   LogicalResult convertSourceMemRefToDescriptor(
1028       ConversionPatternRewriter &rewriter, Type srcType,
1029       memref::ReinterpretCastOp castOp,
1030       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1031     MemRefType targetMemRefType =
1032         castOp.getResult().getType().cast<MemRefType>();
1033     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1034                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1035     if (!llvmTargetDescriptorTy)
1036       return failure();
1037 
1038     // Create descriptor.
1039     Location loc = castOp.getLoc();
1040     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1041 
1042     // Set allocated and aligned pointers.
1043     Value allocatedPtr, alignedPtr;
1044     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1045                              castOp.source(), adaptor.source(), &allocatedPtr,
1046                              &alignedPtr);
1047     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1048     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1049 
1050     // Set offset.
1051     if (castOp.isDynamicOffset(0))
1052       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1053     else
1054       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1055 
1056     // Set sizes and strides.
1057     unsigned dynSizeId = 0;
1058     unsigned dynStrideId = 0;
1059     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1060       if (castOp.isDynamicSize(i))
1061         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1062       else
1063         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1064 
1065       if (castOp.isDynamicStride(i))
1066         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1067       else
1068         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1069     }
1070     *descriptor = desc;
1071     return success();
1072   }
1073 };
1074 
1075 struct MemRefReshapeOpLowering
1076     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1077   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1078 
1079   LogicalResult
1080   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1081                   ConversionPatternRewriter &rewriter) const override {
1082     Type srcType = reshapeOp.source().getType();
1083 
1084     Value descriptor;
1085     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1086                                                adaptor, &descriptor)))
1087       return failure();
1088     rewriter.replaceOp(reshapeOp, {descriptor});
1089     return success();
1090   }
1091 
1092 private:
1093   LogicalResult
1094   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1095                                   Type srcType, memref::ReshapeOp reshapeOp,
1096                                   memref::ReshapeOp::Adaptor adaptor,
1097                                   Value *descriptor) const {
1098     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1099     if (shapeMemRefType.hasStaticShape()) {
1100       MemRefType targetMemRefType =
1101           reshapeOp.getResult().getType().cast<MemRefType>();
1102       auto llvmTargetDescriptorTy =
1103           typeConverter->convertType(targetMemRefType)
1104               .dyn_cast_or_null<LLVM::LLVMStructType>();
1105       if (!llvmTargetDescriptorTy)
1106         return failure();
1107 
1108       // Create descriptor.
1109       Location loc = reshapeOp.getLoc();
1110       auto desc =
1111           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1112 
1113       // Set allocated and aligned pointers.
1114       Value allocatedPtr, alignedPtr;
1115       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1116                                reshapeOp.source(), adaptor.source(),
1117                                &allocatedPtr, &alignedPtr);
1118       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1119       desc.setAlignedPtr(rewriter, loc, alignedPtr);
1120 
1121       // Extract the offset and strides from the type.
1122       int64_t offset;
1123       SmallVector<int64_t> strides;
1124       if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1125         return rewriter.notifyMatchFailure(
1126             reshapeOp, "failed to get stride and offset exprs");
1127 
1128       if (!isStaticStrideOrOffset(offset))
1129         return rewriter.notifyMatchFailure(reshapeOp,
1130                                            "dynamic offset is unsupported");
1131 
1132       desc.setConstantOffset(rewriter, loc, offset);
1133 
1134       assert(targetMemRefType.getLayout().isIdentity() &&
1135              "Identity layout map is a precondition of a valid reshape op");
1136 
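      // Strides of an identity-layout memref are the running products of the
      // inner dimension sizes; e.g. for memref<?x4x8xf32> the strides are
      // [32, 8, 1] and happen to be fully static.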
1137       Value stride = nullptr;
1138       int64_t targetRank = targetMemRefType.getRank();
1139       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1140         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant; otherwise keep the running product of the inner
          // dimension sizes computed at the end of each iteration.
1143           stride = createIndexConstant(rewriter, loc, strides[i]);
1144         } else if (!stride) {
1145           // `stride` is null only in the first iteration of the loop.  However,
1146           // since the target memref has an identity layout, we can safely set
1147           // the innermost stride to 1.
1148           stride = createIndexConstant(rewriter, loc, 1);
1149         }
1150 
1151         Value dimSize;
1152         int64_t size = targetMemRefType.getDimSize(i);
        // Use a constant for a static size; for a dynamic size, load it at
        // runtime from the shape operand.
1155         if (!ShapedType::isDynamic(size)) {
1156           dimSize = createIndexConstant(rewriter, loc, size);
1157         } else {
1158           Value shapeOp = reshapeOp.shape();
1159           Value index = createIndexConstant(rewriter, loc, i);
1160           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1161         }
1162 
1163         desc.setSize(rewriter, loc, i, dimSize);
1164         desc.setStride(rewriter, loc, i, stride);
1165 
1166         // Prepare the stride value for the next dimension.
1167         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1168       }
1169 
1170       *descriptor = desc;
1171       return success();
1172     }
1173 
    // The shape is a rank-1 memref with unknown (dynamic) length.
1175     Location loc = reshapeOp.getLoc();
1176     MemRefDescriptor shapeDesc(adaptor.shape());
1177     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1178 
1179     // Extract address space and element type.
1180     auto targetType =
1181         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1182     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1183     Type elementType = targetType.getElementType();
1184 
1185     // Create the unranked memref descriptor that holds the ranked one. The
1186     // inner descriptor is allocated on stack.
1187     auto targetDesc = UnrankedMemRefDescriptor::undef(
1188         rewriter, loc, typeConverter->convertType(targetType));
1189     targetDesc.setRank(rewriter, loc, resultRank);
1190     SmallVector<Value, 4> sizes;
1191     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1192                                            targetDesc, sizes);
1193     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1194         loc, getVoidPtrType(), sizes.front(), llvm::None);
1195     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1196 
1197     // Extract pointers and offset from the source memref.
1198     Value allocatedPtr, alignedPtr, offset;
1199     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1200                              reshapeOp.source(), adaptor.source(),
1201                              &allocatedPtr, &alignedPtr, &offset);
1202 
1203     // Set pointers and offset.
1204     Type llvmElementType = typeConverter->convertType(elementType);
1205     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1206         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1207     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1208                                               elementPtrPtrType, allocatedPtr);
1209     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1210                                             underlyingDescPtr,
1211                                             elementPtrPtrType, alignedPtr);
1212     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1213                                         underlyingDescPtr, elementPtrPtrType,
1214                                         offset);
1215 
    // Use the underlying descriptor pointer as the base for further
    // addressing. Copy over the new shape and compute the strides; for this,
    // we create a loop that runs from rank - 1 down to 0.
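    // The emitted control flow is roughly (illustrative sketch):
    //   init:                 br ^cond(rank - 1, 1)
    //   ^cond(%idx, %stride): %pred = %idx >= 0
    //                         cond_br %pred, ^body, ^remainder
    //   ^body:                copy shape[%idx] into sizes[%idx],
    //                         store %stride into strides[%idx],
    //                         br ^cond(%idx - 1, %stride * shape[%idx])
    //   ^remainder:           ...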
1218     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1219         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1220         elementPtrPtrType);
1221     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1222         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1223     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1224     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1225     Value resultRankMinusOne =
1226         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1227 
1228     Block *initBlock = rewriter.getInsertionBlock();
1229     Type indexType = getTypeConverter()->getIndexType();
1230     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1231 
1232     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1233                                             {indexType, indexType}, {loc, loc});
1234 
1235     // Move the remaining initBlock ops to condBlock.
1236     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1237     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1238 
1239     rewriter.setInsertionPointToEnd(initBlock);
1240     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1241                                 condBlock);
1242     rewriter.setInsertionPointToStart(condBlock);
1243     Value indexArg = condBlock->getArgument(0);
1244     Value strideArg = condBlock->getArgument(1);
1245 
1246     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1247     Value pred = rewriter.create<LLVM::ICmpOp>(
1248         loc, IntegerType::get(rewriter.getContext(), 1),
1249         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1250 
1251     Block *bodyBlock =
1252         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1253     rewriter.setInsertionPointToStart(bodyBlock);
1254 
1255     // Copy size from shape to descriptor.
1256     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1257     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1258         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1259     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1260     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1261                                       targetSizesBase, indexArg, size);
1262 
1263     // Write stride value and compute next one.
1264     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1265                                         targetStridesBase, indexArg, strideArg);
1266     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1267 
1268     // Decrement loop counter and branch back.
1269     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1270     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1271                                 condBlock);
1272 
1273     Block *remainder =
1274         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1275 
1276     // Hook up the cond exit to the remainder.
1277     rewriter.setInsertionPointToEnd(condBlock);
1278     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1279                                     llvm::None);
1280 
1281     // Reset position to beginning of new remainder block.
1282     rewriter.setInsertionPointToStart(remainder);
1283 
1284     *descriptor = targetDesc;
1285     return success();
1286   }
1287 };
1288 
1289 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1290 /// `Value`s.
1291 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1292                                       Type &llvmIndexType,
1293                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1294   return llvm::to_vector<4>(
1295       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1296         if (auto attr = value.dyn_cast<Attribute>())
1297           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1298         return value.get<Value>();
1299       }));
1300 }
1301 
1302 /// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
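/// For example, for the reassociation [[0, 1], [2]] the resulting map is
/// {0 -> 0, 1 -> 0, 2 -> 1}.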
1305 static DenseMap<int64_t, int64_t>
1306 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1307   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1308   for (auto &en : enumerate(reassociation)) {
1309     for (auto dim : en.value())
1310       expandedDimToCollapsedDim[dim] = en.index();
1311   }
1312   return expandedDimToCollapsedDim;
1313 }
1314 
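/// Compute the size of the `outDimIndex`-th dimension of the expanded result.
/// Static sizes are returned as attributes. A dynamic size is recovered by
/// dividing the corresponding source dimension size by the product of the
/// other (necessarily static) sizes in the same reassociation group. For
/// example, expanding `memref<?xf32>` into `memref<?x4xf32>` with
/// reassociation [[0, 1]] yields outDim0 = inDim0 / 4.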
1315 static OpFoldResult
1316 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1317                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1318                          MemRefDescriptor &inDesc,
1319                          ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
1321                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1322   int64_t outDimSize = outStaticShape[outDimIndex];
1323   if (!ShapedType::isDynamic(outDimSize))
1324     return b.getIndexAttr(outDimSize);
1325 
  // Compute the product of all the static output dim sizes in this
  // reassociation group except the current dim.
1328   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1329   int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
1332       continue;
1333     int64_t otherDimSize = outStaticShape[otherDimIndex];
1334     assert(!ShapedType::isDynamic(otherDimSize) &&
1335            "single dimension cannot be expanded into multiple dynamic "
1336            "dimensions");
1337     otherDimSizesMul *= otherDimSize;
1338   }
1339 
  // outDimSize = inDimSize / otherDimSizesMul
1341   int64_t inDimSize = inStaticShape[inDimIndex];
1342   Value inDimSizeDynamic =
1343       ShapedType::isDynamic(inDimSize)
1344           ? inDesc.size(b, loc, inDimIndex)
1345           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1346                                        b.getIndexAttr(inDimSize));
1347   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1348       loc, inDimSizeDynamic,
1349       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1350                                  b.getIndexAttr(otherDimSizesMul)));
1351   return outDimSizeDynamic;
1352 }
1353 
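/// Compute the size of the `outDimIndex`-th dimension of the collapsed result.
/// Static sizes are returned as attributes; a dynamic size is the product of
/// the source sizes in the corresponding reassociation group. For example,
/// collapsing `memref<?x4xf32>` with reassociation [[0, 1]] yields
/// outDim0 = inDim0 * 4.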
1354 static OpFoldResult getCollapsedOutputDimSize(
1355     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1356     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
1358   if (!ShapedType::isDynamic(outDimSize))
1359     return b.getIndexAttr(outDimSize);
1360 
1361   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1362   Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
1364     int64_t inDimSize = inStaticShape[inDimIndex];
1365     Value inDimSizeDynamic =
1366         ShapedType::isDynamic(inDimSize)
1367             ? inDesc.size(b, loc, inDimIndex)
1368             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1369                                          b.getIndexAttr(inDimSize));
1370     outDimSizeDynamic =
1371         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1372   }
1373   return outDimSizeDynamic;
1374 }
1375 
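/// Compute the shape of the collapsed result as a mix of attributes (static
/// sizes) and values (dynamic sizes), one entry per output dimension.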
1376 static SmallVector<OpFoldResult, 4>
1377 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1378                         ArrayRef<ReassociationIndices> reassociation,
1379                         ArrayRef<int64_t> inStaticShape,
1380                         MemRefDescriptor &inDesc,
1381                         ArrayRef<int64_t> outStaticShape) {
1382   return llvm::to_vector<4>(llvm::map_range(
1383       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1384         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1385                                          outStaticShape[outDimIndex],
1386                                          inStaticShape, inDesc, reassociation);
1387       }));
1388 }
1389 
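/// Compute the shape of the expanded result as a mix of attributes (static
/// sizes) and values (dynamic sizes), one entry per output dimension.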
1390 static SmallVector<OpFoldResult, 4>
1391 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1392                        ArrayRef<ReassociationIndices> reassociation,
1393                        ArrayRef<int64_t> inStaticShape,
1394                        MemRefDescriptor &inDesc,
1395                        ArrayRef<int64_t> outStaticShape) {
1396   DenseMap<int64_t, int64_t> outDimToInDimMap =
1397       getExpandedDimToCollapsedDimMap(reassociation);
1398   return llvm::to_vector<4>(llvm::map_range(
1399       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1400         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1401                                         outStaticShape, inDesc, inStaticShape,
1402                                         reassociation, outDimToInDimMap);
1403       }));
1404 }
1405 
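/// Compute the output shape as SSA values, dispatching to the collapsed or
/// expanded variant depending on whether the rank decreases or increases.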
1406 static SmallVector<Value>
1407 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1408                       ArrayRef<ReassociationIndices> reassociation,
1409                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1410                       ArrayRef<int64_t> outStaticShape) {
1411   return outStaticShape.size() < inStaticShape.size()
1412              ? getAsValues(b, loc, llvmIndexType,
1413                            getCollapsedOutputShape(b, loc, llvmIndexType,
1414                                                    reassociation, inStaticShape,
1415                                                    inDesc, outStaticShape))
1416              : getAsValues(b, loc, llvmIndexType,
1417                            getExpandedOutputShape(b, loc, llvmIndexType,
1418                                                   reassociation, inStaticShape,
1419                                                   inDesc, outStaticShape));
1420 }
1421 
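/// Fill in the strides of the expanded descriptor. Each reassociation group
/// starts from the stride of the corresponding source dimension; within the
/// group, the innermost expanded dimension gets that stride and each outer
/// dimension gets the running product with the inner sizes. For example,
/// expanding a dimension with stride `s` into sizes [a, b] yields strides
/// [s * b, s].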
1422 static void fillInStridesForExpandedMemDescriptor(
1423     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1424     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1425   // See comments for computeExpandedLayoutMap for details on how the strides
1426   // are calculated.
1427   for (auto &en : llvm::enumerate(reassociation)) {
1428     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1429     for (auto dstIndex : llvm::reverse(en.value())) {
1430       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1431       Value size = dstDesc.size(b, loc, dstIndex);
1432       currentStrideToExpand =
1433           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1434     }
1435   }
1436 }
1437 
1438 static void fillInStridesForCollapsedMemDescriptor(
1439     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1440     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1441     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1442   // See comments for computeCollapsedLayoutMap for details on how the strides
1443   // are calculated.
1444   auto srcShape = srcType.getShape();
1445   for (auto &en : llvm::enumerate(reassociation)) {
1446     rewriter.setInsertionPoint(op);
1447     auto dstIndex = en.index();
1448     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1449     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1450       ref = ref.drop_back();
1451     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1452       dstDesc.setStride(rewriter, loc, dstIndex,
1453                         srcDesc.stride(rewriter, loc, ref.back()));
1454     } else {
1455       // Iterate over the source strides in reverse order. Skip over the
1456       // dimensions whose size is 1.
1457       // TODO: we should take the minimum stride in the reassociation group
1458       // instead of just the first where the dimension is not 1.
1459       //
1460       // +------------------------------------------------------+
1461       // | curEntry:                                            |
1462       // |   %srcStride = strides[srcIndex]                     |
1463       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1464       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1465       // +-------------------------+----------------------------+  |
1466       //                           |                               |
1467       //                           v                               |
1468       //            +-----------------------------+                |
1469       //            | nextEntry:                  |                |
1470       //            |   ...                       +---+            |
1471       //            +--------------+--------------+   |            |
1472       //                           |                  |            |
1473       //                           v                  |            |
1474       //            +-----------------------------+   |            |
1475       //            | nextEntry:                  |   |            |
1476       //            |   ...                       |   |            |
1477       //            +--------------+--------------+   |   +--------+
1478       //                           |                  |   |
1479       //                           v                  v   v
1480       //   +--------------------------------------------------+
1481       //   | continue(%newStride):                            |
1482       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1483       //   +--------------------------------------------------+
1484       OpBuilder::InsertionGuard guard(rewriter);
1485       Block *initBlock = rewriter.getInsertionBlock();
1486       Block *continueBlock =
1487           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1488       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1489       rewriter.setInsertionPointToStart(continueBlock);
1490       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1491 
1492       Block *curEntryBlock = initBlock;
1493       Block *nextEntryBlock;
1494       for (auto srcIndex : llvm::reverse(ref)) {
1495         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1496           continue;
1497         rewriter.setInsertionPointToEnd(curEntryBlock);
1498         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1499         if (srcIndex == ref.front()) {
1500           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1501           break;
1502         }
        // Use the descriptor's index type so that the icmp below compares
        // operands of matching types.
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, srcDesc.getIndexType(),
            rewriter.getIntegerAttr(srcDesc.getIndexType(), 1));
1506         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1507             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1508             one);
1509         {
1510           OpBuilder::InsertionGuard guard(rewriter);
1511           nextEntryBlock = rewriter.createBlock(
1512               initBlock->getParent(), Region::iterator(continueBlock), {});
1513         }
1514         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1515                                         srcStride, nextEntryBlock, llvm::None);
1516         curEntryBlock = nextEntryBlock;
1517       }
1518     }
1519   }
1520 }
1521 
1522 static void fillInDynamicStridesForMemDescriptor(
1523     ConversionPatternRewriter &b, Location loc, Operation *op,
1524     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1525     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1526     ArrayRef<ReassociationIndices> reassociation) {
1527   if (srcType.getRank() > dstType.getRank())
1528     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1529                                            srcDesc, dstDesc, reassociation);
1530   else
1531     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1532                                           reassociation);
1533 }
1534 
// Conversion for the reassociating reshape ops, i.e. memref.expand_shape and
// memref.collapse_shape. A reshape op creates a new view descriptor of the
// proper rank: static sizes and strides are materialized as constants, and
// dynamic ones are recomputed from the source descriptor and the
// reassociation indices.
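// For example (illustrative, types abbreviated):
//   %1 = memref.collapse_shape %0 [[0, 1]]
//       : memref<4x4xf32> into memref<16xf32>
// produces a descriptor that reuses the allocated/aligned pointers and offset
// of %0 and carries the sizes/strides of the result type.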
1538 template <typename ReshapeOp>
1539 class ReassociatingReshapeOpConversion
1540     : public ConvertOpToLLVMPattern<ReshapeOp> {
1541 public:
1542   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1543   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1544 
1545   LogicalResult
1546   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1547                   ConversionPatternRewriter &rewriter) const override {
1548     MemRefType dstType = reshapeOp.getResultType();
1549     MemRefType srcType = reshapeOp.getSrcType();
1550 
1551     int64_t offset;
1552     SmallVector<int64_t, 4> strides;
1553     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1554       return rewriter.notifyMatchFailure(
1555           reshapeOp, "failed to get stride and offset exprs");
1556     }
1557 
1558     MemRefDescriptor srcDesc(adaptor.src());
1559     Location loc = reshapeOp->getLoc();
1560     auto dstDesc = MemRefDescriptor::undef(
1561         rewriter, loc, this->typeConverter->convertType(dstType));
1562     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1563     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1564     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1565 
1566     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1567     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1568     Type llvmIndexType =
1569         this->typeConverter->convertType(rewriter.getIndexType());
1570     SmallVector<Value> dstShape = getDynamicOutputShape(
1571         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1572         srcStaticShape, srcDesc, dstStaticShape);
1573     for (auto &en : llvm::enumerate(dstShape))
1574       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1575 
1576     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1577       for (auto &en : llvm::enumerate(strides))
1578         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1579     } else if (srcType.getLayout().isIdentity() &&
1580                dstType.getLayout().isIdentity()) {
1581       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1582                                                    rewriter.getIndexAttr(1));
1583       Value stride = c1;
1584       for (auto dimIndex :
1585            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1586         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1587         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1588       }
1589     } else {
1590       // There could be mixed static/dynamic strides. For simplicity, we
1591       // recompute all strides if there is at least one dynamic stride.
1592       fillInDynamicStridesForMemDescriptor(
1593           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1594           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1595     }
1596     rewriter.replaceOp(reshapeOp, {dstDesc});
1597     return success();
1598   }
1599 };
1600 
1601 /// Conversion pattern that transforms a subview op into:
1602 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1603 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1604 ///      and stride.
1605 /// The subview op is replaced by the descriptor.
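///
/// For example (illustrative):
///   %1 = memref.subview %0[4, 4][8, 8][1, 1]
///       : memref<64x64xf32> to
///         memref<8x8xf32, affine_map<(d0, d1) -> (d0 * 64 + d1 + 260)>>
/// produces a descriptor whose offset is 4 * 64 + 4 and whose strides are
/// [64, 1].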
1606 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1607   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1608 
1609   LogicalResult
1610   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1611                   ConversionPatternRewriter &rewriter) const override {
1612     auto loc = subViewOp.getLoc();
1613 
1614     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1615     auto sourceElementTy =
1616         typeConverter->convertType(sourceMemRefType.getElementType());
1617 
1618     auto viewMemRefType = subViewOp.getType();
1619     auto inferredType = memref::SubViewOp::inferResultType(
1620                             subViewOp.getSourceType(),
1621                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1622                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1623                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1624                             .cast<MemRefType>();
1625     auto targetElementTy =
1626         typeConverter->convertType(viewMemRefType.getElementType());
1627     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1628     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1629         !LLVM::isCompatibleType(sourceElementTy) ||
1630         !LLVM::isCompatibleType(targetElementTy) ||
1631         !LLVM::isCompatibleType(targetDescTy))
1632       return failure();
1633 
1634     // Extract the offset and strides from the type.
1635     int64_t offset;
1636     SmallVector<int64_t, 4> strides;
1637     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1638     if (failed(successStrides))
1639       return failure();
1640 
1641     // Create the descriptor.
1642     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1643       return failure();
1644     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1645     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1646 
1647     // Copy the buffer pointer from the old descriptor to the new one.
1648     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1649     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1650         loc,
1651         LLVM::LLVMPointerType::get(targetElementTy,
1652                                    viewMemRefType.getMemorySpaceAsInt()),
1653         extracted);
1654     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1655 
1656     // Copy the aligned pointer from the old descriptor to the new one.
1657     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1658     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1659         loc,
1660         LLVM::LLVMPointerType::get(targetElementTy,
1661                                    viewMemRefType.getMemorySpaceAsInt()),
1662         extracted);
1663     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1664 
1665     size_t inferredShapeRank = inferredType.getRank();
1666     size_t resultShapeRank = viewMemRefType.getRank();
1667 
1668     // Extract strides needed to compute offset.
1669     SmallVector<Value, 4> strideValues;
1670     strideValues.reserve(inferredShapeRank);
1671     for (unsigned i = 0; i < inferredShapeRank; ++i)
1672       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1673 
1674     // Offset.
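    // For a dynamic offset, the new offset is computed as
    //   newOffset = srcOffset + sum_i(subViewOffset_i * srcStride_i).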
1675     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1676     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1677       targetMemRef.setConstantOffset(rewriter, loc, offset);
1678     } else {
1679       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1680       // `inferredShapeRank` may be larger than the number of offset operands
1681       // because of trailing semantics. In this case, the offset is guaranteed
1682       // to be interpreted as 0 and we can just skip the extra dimensions.
1683       for (unsigned i = 0, e = std::min(inferredShapeRank,
1684                                         subViewOp.getMixedOffsets().size());
1685            i < e; ++i) {
1686         Value offset =
1687             // TODO: need OpFoldResult ODS adaptor to clean this up.
1688             subViewOp.isDynamicOffset(i)
1689                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1690                 : rewriter.create<LLVM::ConstantOp>(
1691                       loc, llvmIndexType,
1692                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1693         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1694         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1695       }
1696       targetMemRef.setOffset(rewriter, loc, baseOffset);
1697     }
1698 
1699     // Update sizes and strides.
1700     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1701     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1702     assert(mixedSizes.size() == mixedStrides.size() &&
1703            "expected sizes and strides of equal length");
1704     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1705     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1706          i >= 0 && j >= 0; --i) {
1707       if (unusedDims.test(i))
1708         continue;
1709 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full source dimension size and the stride is guaranteed to be 1.
1713       Value size, stride;
1714       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1715         // If the static size is available, use it directly. This is similar to
1716         // the folding of dim(constant-op) but removes the need for dim to be
1717         // aware of LLVM constants and for this pass to be aware of std
1718         // constants.
1719         int64_t staticSize =
1720             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1721         if (staticSize != ShapedType::kDynamicSize) {
1722           size = rewriter.create<LLVM::ConstantOp>(
1723               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1724         } else {
1725           Value pos = rewriter.create<LLVM::ConstantOp>(
1726               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1727           Value dim =
1728               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1729           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1730               loc, llvmIndexType, dim);
1731           size = cast.getResult(0);
1732         }
1733         stride = rewriter.create<LLVM::ConstantOp>(
1734             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1735       } else {
1736         // TODO: need OpFoldResult ODS adaptor to clean this up.
1737         size =
1738             subViewOp.isDynamicSize(i)
1739                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1740                 : rewriter.create<LLVM::ConstantOp>(
1741                       loc, llvmIndexType,
1742                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1743         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1744           stride = rewriter.create<LLVM::ConstantOp>(
1745               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1746         } else {
1747           stride =
1748               subViewOp.isDynamicStride(i)
1749                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1750                   : rewriter.create<LLVM::ConstantOp>(
1751                         loc, llvmIndexType,
1752                         rewriter.getI64IntegerAttr(
1753                             subViewOp.getStaticStride(i)));
1754           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1755         }
1756       }
1757       targetMemRef.setSize(rewriter, loc, j, size);
1758       targetMemRef.setStride(rewriter, loc, j, stride);
1759       j--;
1760     }
1761 
1762     rewriter.replaceOp(subViewOp, {targetMemRef});
1763     return success();
1764   }
1765 };
1766 
1767 /// Conversion pattern that transforms a transpose op into:
1768 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
1769 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
1770 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1771 ///      and stride. Size and stride are permutations of the original values.
1772 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1773 /// The transpose op is replaced by the alloca'ed pointer.
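///
/// For example (illustrative, types elided), transposing a 2-D memref with the
/// permutation (i, j) -> (j, i) simply swaps the size/stride pairs of the two
/// dimensions in the descriptor; the pointers and offset are copied unchanged.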
1774 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1775 public:
1776   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1777 
1778   LogicalResult
1779   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1780                   ConversionPatternRewriter &rewriter) const override {
1781     auto loc = transposeOp.getLoc();
1782     MemRefDescriptor viewMemRef(adaptor.in());
1783 
1784     // No permutation, early exit.
1785     if (transposeOp.permutation().isIdentity())
1786       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1787 
1788     auto targetMemRef = MemRefDescriptor::undef(
1789         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1790 
1791     // Copy the base and aligned pointers from the old descriptor to the new
1792     // one.
1793     targetMemRef.setAllocatedPtr(rewriter, loc,
1794                                  viewMemRef.allocatedPtr(rewriter, loc));
1795     targetMemRef.setAlignedPtr(rewriter, loc,
1796                                viewMemRef.alignedPtr(rewriter, loc));
1797 
1798     // Copy the offset pointer from the old descriptor to the new one.
1799     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1800 
1801     // Iterate over the dimensions and apply size/stride permutation.
1802     for (const auto &en :
1803          llvm::enumerate(transposeOp.permutation().getResults())) {
1804       int sourcePos = en.index();
1805       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1806       targetMemRef.setSize(rewriter, loc, targetPos,
1807                            viewMemRef.size(rewriter, loc, sourcePos));
1808       targetMemRef.setStride(rewriter, loc, targetPos,
1809                              viewMemRef.stride(rewriter, loc, sourcePos));
1810     }
1811 
1812     rewriter.replaceOp(transposeOp, {targetMemRef});
1813     return success();
1814   }
1815 };
1816 
1817 /// Conversion pattern that transforms an op into:
1818 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1819 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1820 ///      and stride.
1821 /// The view op is replaced by the descriptor.
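///
/// For example (illustrative):
///   %1 = memref.view %0[%shift][%size]
///       : memref<2048xi8> to memref<?x64xf32>
/// reuses the allocated pointer of %0, advances the aligned pointer by %shift
/// bytes, and fills in the sizes/strides of the contiguous result type.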
1822 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1823   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1824 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand.
1827   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1828                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1829                 unsigned idx) const {
1830     assert(idx < shape.size());
1831     if (!ShapedType::isDynamic(shape[idx]))
1832       return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx) to locate the
    // corresponding dynamic size operand.
1834     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1835       return ShapedType::isDynamic(v);
1836     });
1837     return dynamicSizes[nDynamic];
1838   }
1839 
1840   // Build and return the idx^th stride, either by returning the constant stride
1841   // or by computing the dynamic stride from the current `runningStride` and
1842   // `nextSize`. The caller should keep a running stride and update it with the
1843   // result returned by this function.
1844   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1845                   ArrayRef<int64_t> strides, Value nextSize,
1846                   Value runningStride, unsigned idx) const {
1847     assert(idx < strides.size());
1848     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1849       return createIndexConstant(rewriter, loc, strides[idx]);
1850     if (nextSize)
1851       return runningStride
1852                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1853                  : nextSize;
1854     assert(!runningStride);
1855     return createIndexConstant(rewriter, loc, 1);
1856   }
1857 
1858   LogicalResult
1859   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1860                   ConversionPatternRewriter &rewriter) const override {
1861     auto loc = viewOp.getLoc();
1862 
1863     auto viewMemRefType = viewOp.getType();
1864     auto targetElementTy =
1865         typeConverter->convertType(viewMemRefType.getElementType());
1866     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1867     if (!targetDescTy || !targetElementTy ||
1868         !LLVM::isCompatibleType(targetElementTy) ||
1869         !LLVM::isCompatibleType(targetDescTy))
1870       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1871              failure();
1872 
1873     int64_t offset;
1874     SmallVector<int64_t, 4> strides;
1875     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1876     if (failed(successStrides))
1877       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1878     assert(offset == 0 && "expected offset to be 0");
1879 
1880     // Target memref must be contiguous in memory (innermost stride is 1), or
1881     // empty (special case when at least one of the memref dimensions is 0).
1882     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1883       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1884              failure();
1885 
1886     // Create the descriptor.
1887     MemRefDescriptor sourceMemRef(adaptor.source());
1888     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1889 
1890     // Field 1: Copy the allocated pointer, used for malloc/free.
1891     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1892     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1893     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1894         loc,
1895         LLVM::LLVMPointerType::get(targetElementTy,
1896                                    srcMemRefType.getMemorySpaceAsInt()),
1897         allocatedPtr);
1898     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1899 
1900     // Field 2: Copy the actual aligned pointer to payload.
1901     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1902     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1903                                               alignedPtr, adaptor.byte_shift());
1904     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1905         loc,
1906         LLVM::LLVMPointerType::get(targetElementTy,
1907                                    srcMemRefType.getMemorySpaceAsInt()),
1908         alignedPtr);
1909     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1910 
1911     // Field 3: The offset in the resulting type must be 0. This is because of
1912     // the type change: an offset on srcType* may not be expressible as an
1913     // offset on dstType*.
1914     targetMemRef.setOffset(rewriter, loc,
1915                            createIndexConstant(rewriter, loc, offset));
1916 
1917     // Early exit for 0-D corner case.
1918     if (viewMemRefType.getRank() == 0)
1919       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1920 
1921     // Fields 4 and 5: Update sizes and strides.
1922     Value stride = nullptr, nextSize = nullptr;
1923     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1924       // Update size.
1925       Value size =
1926           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1927       targetMemRef.setSize(rewriter, loc, i, size);
1928       // Update stride.
1929       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1930       targetMemRef.setStride(rewriter, loc, i, stride);
1931       nextSize = size;
1932     }
1933 
1934     rewriter.replaceOp(viewOp, {targetMemRef});
1935     return success();
1936   }
1937 };
1938 
1939 //===----------------------------------------------------------------------===//
1940 // AtomicRMWOpLowering
1941 //===----------------------------------------------------------------------===//
1942 
1943 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1944 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
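///
/// For example (illustrative), `arith::AtomicRMWKind::addf` maps to
/// `LLVM::AtomicBinOp::fadd`; kinds without a direct LLVM equivalent return
/// llvm::None and are handled by the generic cmpxchg-based lowering instead.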
1945 static Optional<LLVM::AtomicBinOp>
1946 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1947   switch (atomicOp.kind()) {
1948   case arith::AtomicRMWKind::addf:
1949     return LLVM::AtomicBinOp::fadd;
1950   case arith::AtomicRMWKind::addi:
1951     return LLVM::AtomicBinOp::add;
1952   case arith::AtomicRMWKind::assign:
1953     return LLVM::AtomicBinOp::xchg;
1954   case arith::AtomicRMWKind::maxs:
1955     return LLVM::AtomicBinOp::max;
1956   case arith::AtomicRMWKind::maxu:
1957     return LLVM::AtomicBinOp::umax;
1958   case arith::AtomicRMWKind::mins:
1959     return LLVM::AtomicBinOp::min;
1960   case arith::AtomicRMWKind::minu:
1961     return LLVM::AtomicBinOp::umin;
1962   case arith::AtomicRMWKind::ori:
1963     return LLVM::AtomicBinOp::_or;
1964   case arith::AtomicRMWKind::andi:
1965     return LLVM::AtomicBinOp::_and;
1966   default:
1967     return llvm::None;
1968   }
1969   llvm_unreachable("Invalid AtomicRMWKind");
1970 }
1971 
1972 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1973   using Base::Base;
1974 
1975   LogicalResult
1976   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1977                   ConversionPatternRewriter &rewriter) const override {
1978     if (failed(match(atomicOp)))
1979       return failure();
1980     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1981     if (!maybeKind)
1982       return failure();
1983     auto resultType = adaptor.value().getType();
1984     auto memRefType = atomicOp.getMemRefType();
1985     auto dataPtr =
1986         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1987                              adaptor.indices(), rewriter);
1988     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1989         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1990         LLVM::AtomicOrdering::acq_rel);
1991     return success();
1992   }
1993 };
1994 
1995 } // namespace
1996 
1997 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1998                                                   RewritePatternSet &patterns) {
1999   // clang-format off
2000   patterns.add<
2001       AllocaOpLowering,
2002       AllocaScopeOpLowering,
2003       AtomicRMWOpLowering,
2004       AssumeAlignmentOpLowering,
2005       DimOpLowering,
2006       GenericAtomicRMWOpLowering,
2007       GlobalMemrefOpLowering,
2008       GetGlobalMemrefOpLowering,
2009       LoadOpLowering,
2010       MemRefCastOpLowering,
2011       MemRefCopyOpLowering,
2012       MemRefReinterpretCastOpLowering,
2013       MemRefReshapeOpLowering,
2014       PrefetchOpLowering,
2015       RankOpLowering,
2016       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2017       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2018       StoreOpLowering,
2019       SubViewOpLowering,
2020       TransposeOpLowering,
2021       ViewOpLowering>(converter);
2022   // clang-format on
2023   auto allocLowering = converter.getOptions().allocLowering;
2024   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2025     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
2026   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2027     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
2028 }
2029 
2030 namespace {
2031 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2032   MemRefToLLVMPass() = default;
2033 
2034   void runOnOperation() override {
2035     Operation *op = getOperation();
2036     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2037     LowerToLLVMOptions options(&getContext(),
2038                                dataLayoutAnalysis.getAtOrAbove(op));
2039     options.allocLowering =
2040         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2041                          : LowerToLLVMOptions::AllocLowering::Malloc);
2042     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2043       options.overrideIndexBitwidth(indexBitwidth);
2044 
2045     LLVMTypeConverter typeConverter(&getContext(), options,
2046                                     &dataLayoutAnalysis);
2047     RewritePatternSet patterns(&getContext());
2048     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2049     LLVMConversionTarget target(getContext());
2050     target.addLegalOp<func::FuncOp>();
2051     if (failed(applyPartialConversion(op, target, std::move(patterns))))
2052       signalPassFailure();
2053   }
2054 };
2055 } // namespace
2056 
2057 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2058   return std::make_unique<MemRefToLLVMPass>();
2059 }
2060