1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
30   return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
31 }
32 
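/// Lowers `memref.alloc` to a call to `malloc`, realigning the returned
/// pointer by hand when an alignment is requested. A rough, illustrative
/// sketch of the generated IR (names, types, and exact op spellings are
/// approximate, not the literal output):
///
///   %alloc = memref.alloc() {alignment = 16} : memref<32xf32>
///   // becomes, roughly:
///   %size    = ...                          // 32 * sizeof(f32) + 16
///   %raw     = llvm.call @malloc(%size)
///   %int     = llvm.ptrtoint %raw
///   %aligned = llvm.inttoptr <%int rounded up to a multiple of 16>
///   // %raw and %aligned become the allocated/aligned pointers of the
///   // resulting memref descriptor.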
33 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
34   AllocOpLowering(LLVMTypeConverter &converter)
35       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
36                                 converter) {}
37 
38   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
39                                           Location loc, Value sizeBytes,
40                                           Operation *op) const override {
41     // Heap allocations.
42     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
43     MemRefType memRefType = allocOp.getType();
44 
45     Value alignment;
46     if (auto alignmentAttr = allocOp.alignment()) {
47       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
48     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // When no alignment is specified, we may want to override `malloc`'s
      // behavior: `malloc` typically aligns only to the size of the largest
      // scalar type on the target hardware. For non-scalar element types, use
      // the natural alignment of the LLVM type given by the LLVM DataLayout.
53       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
54     }
55 
56     if (alignment) {
57       // Adjust the allocation size to consider alignment.
58       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
59     }
60 
61     // Allocate the underlying buffer and store a pointer to it in the MemRef
62     // descriptor.
63     Type elementPtrType = this->getElementPtrType(memRefType);
64     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
65         allocOp->getParentOfType<ModuleOp>(), getIndexType());
66     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
67                                   getVoidPtrType());
68     Value allocatedPtr =
69         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
70 
71     Value alignedPtr = allocatedPtr;
72     if (alignment) {
73       // Compute the aligned type pointer.
74       Value allocatedInt =
75           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
76       Value alignmentInt =
77           createAligned(rewriter, loc, allocatedInt, alignment);
78       alignedPtr =
79           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
80     }
81 
82     return std::make_tuple(allocatedPtr, alignedPtr);
83   }
84 };
85 
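/// Lowers `memref.alloc` to a call to `aligned_alloc` instead of `malloc`.
/// This pattern is intended to be registered when the conversion is configured
/// to use aligned allocation (an assumption about how the pass selects its
/// patterns). Since `aligned_alloc` already returns a pointer satisfying the
/// requested alignment, no manual realignment code is emitted and the
/// allocated and aligned pointers coincide.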
86 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
87   AlignedAllocOpLowering(LLVMTypeConverter &converter)
88       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
89                                 converter) {}
90 
91   /// Returns the memref's element size in bytes using the data layout active at
92   /// `op`.
93   // TODO: there are other places where this is used. Expose publicly?
94   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
95     const DataLayout *layout = &defaultLayout;
96     if (const DataLayoutAnalysis *analysis =
97             getTypeConverter()->getDataLayoutAnalysis()) {
98       layout = &analysis->getAbove(op);
99     }
100     Type elementType = memRefType.getElementType();
101     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
102       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
103                                                          *layout);
104     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
105       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
106           memRefElementType, *layout);
107     return layout->getTypeSize(elementType);
108   }
109 
  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
112   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
113                               Operation *op) const {
114     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
115     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
116       if (ShapedType::isDynamic(type.getDimSize(i)))
117         continue;
118       sizeDivisor = sizeDivisor * type.getDimSize(i);
119     }
120     return sizeDivisor % factor == 0;
121   }
122 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
126   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
127     if (Optional<uint64_t> alignment = allocOp.alignment())
128       return *alignment;
129 
    // Whenever no alignment is set, use an alignment consistent with the
    // element type; since the alignment has to be a power of two, bump the
    // element size to the next power of two if it isn't one already.
133     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
134     return std::max(kMinAlignedAllocAlignment,
135                     llvm::PowerOf2Ceil(eltSizeBytes));
136   }
137 
138   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
139                                           Location loc, Value sizeBytes,
140                                           Operation *op) const override {
141     // Heap allocations.
142     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
143     MemRefType memRefType = allocOp.getType();
144     int64_t alignment = getAllocationAlignment(allocOp);
145     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
146 
147     // aligned_alloc requires size to be a multiple of alignment; we will pad
148     // the size to the next multiple if necessary.
149     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
150       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
151 
152     Type elementPtrType = this->getElementPtrType(memRefType);
153     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
154         allocOp->getParentOfType<ModuleOp>(), getIndexType());
155     auto results =
156         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
157                        getVoidPtrType());
158     Value allocatedPtr =
159         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
160 
161     return std::make_tuple(allocatedPtr, allocatedPtr);
162   }
163 
164   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
165   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
166 
167   /// Default layout to use in absence of the corresponding analysis.
168   DataLayout defaultLayout;
169 };
170 
// Out-of-line definition, required before C++17 (constexpr static data
// members are implicitly inline since C++17).
172 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
173 
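/// Lowers `memref.alloca` to an `llvm.alloca` of the underlying buffer on the
/// stack. A rough, illustrative picture (the exact size computation and types
/// depend on the descriptor lowering):
///
///   %a = memref.alloca() : memref<4xf32>
///   // becomes, roughly:
///   %buf = llvm.alloca <size> x <element type> -> !llvm.ptr<f32>
///   // %buf serves as both the allocated and the aligned pointer.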
174 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
175   AllocaOpLowering(LLVMTypeConverter &converter)
176       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
177                                 converter) {}
178 
  /// Allocates the underlying buffer on the stack with `llvm.alloca`. Unlike
  /// the heap-allocation lowerings, alloca directly yields a suitably aligned
  /// pointer to the element type, so the allocated and aligned pointers
  /// returned here are the same value.
182   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
183                                           Location loc, Value sizeBytes,
184                                           Operation *op) const override {
185 
    // With alloca, one gets a pointer to the element type right away: there is
    // no separate aligned pointer to compute.
188     auto allocaOp = cast<memref::AllocaOp>(op);
189     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
190 
191     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
192         loc, elementPtrType, sizeBytes,
193         allocaOp.alignment() ? *allocaOp.alignment() : 0);
194 
195     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
196   }
197 };
198 
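/// Lowers `memref.alloca_scope` by inlining its body into the enclosing region
/// and bracketing it with stack save/restore intrinsics, so that allocas made
/// inside the scope are released when control leaves it. Roughly
/// (illustrative, block names made up):
///
///   %sp = llvm.intr.stacksave
///   llvm.br ^body
/// ^body:
///   ...                          // inlined scope body
///   llvm.intr.stackrestore %sp
///   llvm.br ^continue(<results>)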
199 struct AllocaScopeOpLowering
200     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
201   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
202 
203   LogicalResult
204   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
205                   ConversionPatternRewriter &rewriter) const override {
206     OpBuilder::InsertionGuard guard(rewriter);
207     Location loc = allocaScopeOp.getLoc();
208 
209     // Split the current block before the AllocaScopeOp to create the inlining
210     // point.
211     auto *currentBlock = rewriter.getInsertionBlock();
212     auto *remainingOpsBlock =
213         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
214     Block *continueBlock;
215     if (allocaScopeOp.getNumResults() == 0) {
216       continueBlock = remainingOpsBlock;
217     } else {
218       continueBlock = rewriter.createBlock(
219           remainingOpsBlock, allocaScopeOp.getResultTypes(),
220           SmallVector<Location>(allocaScopeOp->getNumResults(),
221                                 allocaScopeOp.getLoc()));
222       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
223     }
224 
225     // Inline body region.
226     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
227     Block *afterBody = &allocaScopeOp.bodyRegion().back();
228     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
229 
230     // Save stack and then branch into the body of the region.
231     rewriter.setInsertionPointToEnd(currentBlock);
232     auto stackSaveOp =
233         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
234     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
235 
    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
238     rewriter.setInsertionPointToEnd(afterBody);
239     auto returnOp =
240         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
241     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
242         returnOp, returnOp.results(), continueBlock);
243 
244     // Insert stack restore before jumping out the body of the region.
245     rewriter.setInsertionPoint(branchOp);
246     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
247 
    // Replace the op with the values returned from the body region.
249     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
250 
251     return success();
252   }
253 };
254 
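/// Lowers `memref.assume_alignment` to an `llvm.intr.assume` over an alignment
/// check of the aligned pointer. A minimal sketch of the emitted IR, assuming
/// an alignment of 16 (types and names are illustrative):
///
///   %p    = llvm.ptrtoint %alignedPtr : !llvm.ptr<f32> to i64
///   %rem  = llvm.and %p, %c15
///   %cond = llvm.icmp "eq" %rem, %c0
///   llvm.intr.assume %cond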
255 struct AssumeAlignmentOpLowering
256     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
257   using ConvertOpToLLVMPattern<
258       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
259 
260   LogicalResult
261   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
262                   ConversionPatternRewriter &rewriter) const override {
263     Value memref = adaptor.memref();
264     unsigned alignment = op.alignment();
265     auto loc = op.getLoc();
266 
267     MemRefDescriptor memRefDescriptor(memref);
268     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
269 
270     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
271     // the asserted memref.alignedPtr isn't used anywhere else, as the real
272     // users like load/store/views always re-extract memref.alignedPtr as they
273     // get lowered.
274     //
275     // This relies on LLVM's CSE optimization (potentially after SROA), since
276     // after CSE all memref.alignedPtr instances get de-duplicated into the same
277     // pointer SSA value.
278     auto intPtrType =
279         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
280     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
281     Value mask =
282         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
283     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
284     rewriter.create<LLVM::AssumeOp>(
285         loc, rewriter.create<LLVM::ICmpOp>(
286                  loc, LLVM::ICmpPredicate::eq,
287                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
288 
289     rewriter.eraseOp(op);
290     return success();
291   }
292 };
293 
294 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
295 // The memref descriptor being an SSA value, there is no need to clean it up
296 // in any way.
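// For example (illustrative, names made up):
//
//   memref.dealloc %m : memref<8xf32>
//   // becomes, roughly:
//   %ptr  = llvm.extractvalue %m_desc[0]    // allocated pointer
//   %void = llvm.bitcast %ptr : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()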
297 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
298   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
299 
300   explicit DeallocOpLowering(LLVMTypeConverter &converter)
301       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
302 
303   LogicalResult
304   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
305                   ConversionPatternRewriter &rewriter) const override {
306     // Insert the `free` declaration if it is not already present.
307     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
308     MemRefDescriptor memref(adaptor.memref());
309     Value casted = rewriter.create<LLVM::BitcastOp>(
310         op.getLoc(), getVoidPtrType(),
311         memref.allocatedPtr(rewriter, op.getLoc()));
312     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
313         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
314     return success();
315   }
316 };
317 
318 // A `dim` is converted to a constant for static sizes and to an access to the
319 // size stored in the memref descriptor for dynamic sizes.
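// For example (illustrative, names made up):
//
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
//   // becomes, roughly, since dimension 1 is dynamic:
//   %d = llvm.extractvalue %m_desc[3, 1]    // sizes array, index 1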
320 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
321   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
322 
323   LogicalResult
324   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
325                   ConversionPatternRewriter &rewriter) const override {
326     Type operandType = dimOp.source().getType();
327     if (operandType.isa<UnrankedMemRefType>()) {
328       rewriter.replaceOp(
329           dimOp, {extractSizeOfUnrankedMemRef(
330                      operandType, dimOp, adaptor.getOperands(), rewriter)});
331 
332       return success();
333     }
334     if (operandType.isa<MemRefType>()) {
335       rewriter.replaceOp(
336           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
337                                             adaptor.getOperands(), rewriter)});
338       return success();
339     }
340     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
341   }
342 
343 private:
344   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
345                                     OpAdaptor adaptor,
346                                     ConversionPatternRewriter &rewriter) const {
347     Location loc = dimOp.getLoc();
348 
349     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
350     auto scalarMemRefType =
351         MemRefType::get({}, unrankedMemRefType.getElementType());
352     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
353 
354     // Extract pointer to the underlying ranked descriptor and bitcast it to a
355     // memref<element_type> descriptor pointer to minimize the number of GEP
356     // operations.
357     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
358     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
359     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
360         loc,
361         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
362                                    addressSpace),
363         underlyingRankedDesc);
364 
365     // Get pointer to offset field of memref<element_type> descriptor.
366     Type indexPtrTy = LLVM::LLVMPointerType::get(
367         getTypeConverter()->getIndexType(), addressSpace);
368     Value two = rewriter.create<LLVM::ConstantOp>(
369         loc, typeConverter->convertType(rewriter.getI32Type()),
370         rewriter.getI32IntegerAttr(2));
371     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
372         loc, indexPtrTy, scalarMemRefDescPtr,
373         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
374 
    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
377     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
378         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
379     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
380                                                  ValueRange({idxPlusOne}));
381     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
382   }
383 
384   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
385     if (Optional<int64_t> idx = dimOp.getConstantIndex())
386       return idx;
387 
388     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
389       return constantOp.getValue()
390           .cast<IntegerAttr>()
391           .getValue()
392           .getSExtValue();
393 
394     return llvm::None;
395   }
396 
397   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
398                                   OpAdaptor adaptor,
399                                   ConversionPatternRewriter &rewriter) const {
400     Location loc = dimOp.getLoc();
401 
    // Take advantage of a constant index, if any.
403     MemRefType memRefType = operandType.cast<MemRefType>();
404     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
405       int64_t i = index.getValue();
406       if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
408         MemRefDescriptor descriptor(adaptor.source());
409         return descriptor.size(rewriter, loc, i);
410       }
411       // Use constant for static size.
412       int64_t dimSize = memRefType.getDimSize(i);
413       return createIndexConstant(rewriter, loc, dimSize);
414     }
415     Value index = adaptor.index();
416     int64_t rank = memRefType.getRank();
417     MemRefDescriptor memrefDescriptor(adaptor.source());
418     return memrefDescriptor.size(rewriter, loc, index, rank);
419   }
420 };
421 
422 /// Common base for load and store operations on MemRefs. Restricts the match
423 /// to supported MemRef types. Provides functionality to emit code accessing a
424 /// specific element of the underlying data buffer.
425 template <typename Derived>
426 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
427   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
428   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
429   using Base = LoadStoreOpLowering<Derived>;
430 
431   LogicalResult match(Derived op) const override {
432     MemRefType type = op.getMemRefType();
433     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
434   }
435 };
436 
437 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
438 /// retried until it succeeds in atomically storing a new value into memory.
439 ///
440 ///      +---------------------------------+
441 ///      |   <code before the AtomicRMWOp> |
442 ///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
444 ///      +---------------------------------+
445 ///             |
446 ///  -------|   |
447 ///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
456 ///  |          |        |
457 ///  |-----------        |
458 ///                      v
459 ///      +--------------------------------+
460 ///      | end:                           |
461 ///      |   <code after the AtomicRMWOp> |
462 ///      +--------------------------------+
463 ///
464 struct GenericAtomicRMWOpLowering
465     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
466   using Base::Base;
467 
468   LogicalResult
469   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
470                   ConversionPatternRewriter &rewriter) const override {
471     auto loc = atomicOp.getLoc();
472     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
473 
474     // Split the block into initial, loop, and ending parts.
475     auto *initBlock = rewriter.getInsertionBlock();
476     auto *loopBlock = rewriter.createBlock(
477         initBlock->getParent(), std::next(Region::iterator(initBlock)),
478         valueType, loc);
479     auto *endBlock = rewriter.createBlock(
480         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
481 
    // The range of operations to be moved to `endBlock`.
483     auto opsToMoveStart = atomicOp->getIterator();
484     auto opsToMoveEnd = initBlock->back().getIterator();
485 
486     // Compute the loaded value and branch to the loop block.
487     rewriter.setInsertionPointToEnd(initBlock);
488     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
489     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
490                                         adaptor.indices(), rewriter);
491     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
492     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
493 
494     // Prepare the body of the loop block.
495     rewriter.setInsertionPointToStart(loopBlock);
496 
497     // Clone the GenericAtomicRMWOp region and extract the result.
498     auto loopArgument = loopBlock->getArgument(0);
499     BlockAndValueMapping mapping;
500     mapping.map(atomicOp.getCurrentValue(), loopArgument);
501     Block &entryBlock = atomicOp.body().front();
502     for (auto &nestedOp : entryBlock.without_terminator()) {
503       Operation *clone = rewriter.clone(nestedOp, mapping);
504       mapping.map(nestedOp.getResults(), clone->getResults());
505     }
506     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
507 
508     // Prepare the epilog of the loop block.
509     // Append the cmpxchg op to the end of the loop block.
510     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
511     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
512     auto boolType = IntegerType::get(rewriter.getContext(), 1);
513     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
514                                                      {valueType, boolType});
515     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
516         loc, pairType, dataPtr, loopArgument, result, successOrdering,
517         failureOrdering);
518     // Extract the %new_loaded and %ok values from the pair.
519     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
520         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
521     Value ok = rewriter.create<LLVM::ExtractValueOp>(
522         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
523 
524     // Conditionally branch to the end or back to the loop depending on %ok.
525     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
526                                     loopBlock, newLoaded);
527 
528     rewriter.setInsertionPointToEnd(endBlock);
529     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
530                  std::next(opsToMoveEnd), rewriter);
531 
    // The 'result' of the generic_atomic_rmw op is the newly loaded value.
533     rewriter.replaceOp(atomicOp, {newLoaded});
534 
535     return success();
536   }
537 
538 private:
539   // Clones a segment of ops [start, end) and erases the original.
540   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
541                     Block::iterator start, Block::iterator end,
542                     ConversionPatternRewriter &rewriter) const {
543     BlockAndValueMapping mapping;
544     mapping.map(oldResult, newResult);
545     SmallVector<Operation *, 2> opsToErase;
546     for (auto it = start; it != end; ++it) {
547       rewriter.clone(*it, mapping);
548       opsToErase.push_back(&*it);
549     }
550     for (auto *it : opsToErase)
551       rewriter.eraseOp(it);
552   }
553 };
554 
555 /// Returns the LLVM type of the global variable given the memref type `type`.
556 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
557                                           LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global ops with an initial value,
  // we do not intend to flatten the ElementsAttr when going from the memref
  // dialect to the LLVM dialect, so the LLVM type needs to be a
  // multi-dimensional array.
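  // For example (illustrative), a global of type memref<4x8xf32> maps to
  // !llvm.array<4 x array<8 x f32>>.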
563   Type elementType = typeConverter.convertType(type.getElementType());
564   Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
566   for (int64_t dim : llvm::reverse(type.getShape()))
567     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
568   return arrayTy;
569 }
570 
/// A `memref.global` is lowered to an LLVM global variable.
572 struct GlobalMemrefOpLowering
573     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
574   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
575 
576   LogicalResult
577   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
578                   ConversionPatternRewriter &rewriter) const override {
579     MemRefType type = global.type();
580     if (!isConvertibleAndHasIdentityMaps(type))
581       return failure();
582 
583     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
584 
585     LLVM::Linkage linkage =
586         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
587 
588     Attribute initialValue = nullptr;
589     if (!global.isExternal() && !global.isUninitialized()) {
590       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
591       initialValue = elementsAttr;
592 
593       // For scalar memrefs, the global variable created is of the element type,
594       // so unpack the elements attribute to extract the value.
595       if (type.getRank() == 0)
596         initialValue = elementsAttr.getSplatValue<Attribute>();
597     }
598 
599     uint64_t alignment = global.alignment().getValueOr(0);
600 
601     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
602         global, arrayTy, global.constant(), linkage, global.sym_name(),
603         initialValue, alignment, type.getMemorySpaceAsInt());
604     if (!global.isExternal() && global.isUninitialized()) {
605       Block *blk = new Block();
606       newGlobal.getInitializerRegion().push_back(blk);
607       rewriter.setInsertionPointToStart(blk);
608       Value undef[] = {
609           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
610       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
611     }
612     return success();
613   }
614 };
615 
/// A `memref.get_global` is lowered into a memref descriptor whose aligned
/// pointer is the address of the first element of the referenced global
/// variable. It derives from `AllocLikeOpLLVMLowering` to reuse the memref
/// descriptor construction logic.
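/// For example (illustrative, names made up):
///
///   %m = memref.get_global @gv : memref<2x3xf32>
///   // becomes, roughly:
///   %addr = llvm.mlir.addressof @gv : !llvm.ptr<array<2 x array<3 x f32>>>
///   %elt  = llvm.getelementptr %addr[0, 0, 0] : ... -> !llvm.ptr<f32>
///   // %elt becomes the aligned pointer of the resulting descriptor.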
619 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
620   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
621       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
622                                 converter) {}
623 
624   /// Buffer "allocation" for memref.get_global op is getting the address of
625   /// the global variable referenced.
626   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
627                                           Location loc, Value sizeBytes,
628                                           Operation *op) const override {
629     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
630     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
631     unsigned memSpace = type.getMemorySpaceAsInt();
632 
633     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
634     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
635         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
636 
637     // Get the address of the first element in the array by creating a GEP with
638     // the address of the GV as the base, and (rank + 1) number of 0 indices.
639     Type elementType = typeConverter->convertType(type.getElementType());
640     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
641 
642     SmallVector<Value> operands;
643     operands.insert(operands.end(), type.getRank() + 1,
644                     createIndexConstant(rewriter, loc, 0));
645     auto gep =
646         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
647 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
651     auto intPtrType = getIntPtrType(memSpace);
652     Value deadBeefConst =
653         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
654     auto deadBeefPtr =
655         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
656 
    // The aligned pointer is the address of the first element; the allocated
    // pointer holds the bad value above. We could potentially stash a nullptr
    // for the allocated pointer instead, since we do not expect any dealloc.
659     return std::make_tuple(deadBeefPtr, gep);
660   }
661 };
662 
663 // Load operation is lowered to obtaining a pointer to the indexed element
664 // and loading it.
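// For example (illustrative, names made up):
//
//   %v = memref.load %m[%i, %j] : memref<4x8xf32>
//   // becomes, roughly:
//   %off = ...                              // linearized index %i * 8 + %j
//   %ptr = llvm.getelementptr %aligned[%off] : ... -> !llvm.ptr<f32>
//   %v   = llvm.load %ptr : !llvm.ptr<f32>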
665 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
666   using Base::Base;
667 
668   LogicalResult
669   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
670                   ConversionPatternRewriter &rewriter) const override {
671     auto type = loadOp.getMemRefType();
672 
673     Value dataPtr = getStridedElementPtr(
674         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
675     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
676     return success();
677   }
678 };
679 
680 // Store operation is lowered to obtaining a pointer to the indexed element,
681 // and storing the given value to it.
682 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
683   using Base::Base;
684 
685   LogicalResult
686   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
687                   ConversionPatternRewriter &rewriter) const override {
688     auto type = op.getMemRefType();
689 
690     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
691                                          adaptor.indices(), rewriter);
692     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
693     return success();
694   }
695 };
696 
// The prefetch operation is lowered in a way similar to the load operation,
// except that the llvm.prefetch operation is used as the replacement.
699 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
700   using Base::Base;
701 
702   LogicalResult
703   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
704                   ConversionPatternRewriter &rewriter) const override {
705     auto type = prefetchOp.getMemRefType();
706     auto loc = prefetchOp.getLoc();
707 
708     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
709                                          adaptor.indices(), rewriter);
710 
711     // Replace with llvm.prefetch.
712     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
713     auto isWrite = rewriter.create<LLVM::ConstantOp>(
714         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
715     auto localityHint = rewriter.create<LLVM::ConstantOp>(
716         loc, llvmI32Type,
717         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
718     auto isData = rewriter.create<LLVM::ConstantOp>(
719         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
720 
721     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
722                                                 localityHint, isData);
723     return success();
724   }
725 };
726 
727 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
728   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
729 
730   LogicalResult
731   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
732                   ConversionPatternRewriter &rewriter) const override {
733     Location loc = op.getLoc();
734     Type operandType = op.memref().getType();
735     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
736       UnrankedMemRefDescriptor desc(adaptor.memref());
737       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
738       return success();
739     }
740     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
741       rewriter.replaceOp(
742           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
743       return success();
744     }
745     return failure();
746   }
747 };
748 
749 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
750   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
751 
752   LogicalResult match(memref::CastOp memRefCastOp) const override {
753     Type srcType = memRefCastOp.getOperand().getType();
754     Type dstType = memRefCastOp.getType();
755 
    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed, we can revisit.
761     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
762       return success(typeConverter->convertType(srcType) ==
763                      typeConverter->convertType(dstType));
764 
    // At least one of the operands has an unranked memref type.
766     assert(srcType.isa<UnrankedMemRefType>() ||
767            dstType.isa<UnrankedMemRefType>());
768 
    // Unranked-to-unranked casts are disallowed.
770     return !(srcType.isa<UnrankedMemRefType>() &&
771              dstType.isa<UnrankedMemRefType>())
772                ? success()
773                : failure();
774   }
775 
776   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
777                ConversionPatternRewriter &rewriter) const override {
778     auto srcType = memRefCastOp.getOperand().getType();
779     auto dstType = memRefCastOp.getType();
780     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
781     auto loc = memRefCastOp.getLoc();
782 
783     // For ranked/ranked case, just keep the original descriptor.
784     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
785       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
786 
787     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting a ranked memref to an unranked memref type:
      //  - set the rank in the destination from the memref type;
      //  - allocate space on the stack and copy the src memref descriptor;
      //  - set the ptr in the destination to the stack space.
792       auto srcMemRefType = srcType.cast<MemRefType>();
793       int64_t rank = srcMemRefType.getRank();
794       // ptr = AllocaOp sizeof(MemRefDescriptor)
795       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
796           loc, adaptor.source(), rewriter);
797       // voidptr = BitCastOp srcType* to void*
798       auto voidPtr =
799           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
800               .getResult();
801       // rank = ConstantOp srcRank
802       auto rankVal = rewriter.create<LLVM::ConstantOp>(
803           loc, getIndexType(), rewriter.getIndexAttr(rank));
804       // undef = UndefOp
805       UnrankedMemRefDescriptor memRefDesc =
806           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
807       // d1 = InsertValueOp undef, rank, 0
808       memRefDesc.setRank(rewriter, loc, rankVal);
809       // d2 = InsertValueOp d1, voidptr, 1
810       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
811       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
812 
813     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from an unranked type to a ranked type.
      // The operation is assumed to perform a correct cast; if the destination
      // type does not match the type of the underlying ranked descriptor, the
      // behavior is undefined.
817       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
818       // ptr = ExtractValueOp src, 1
819       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
820       // castPtr = BitCastOp i8* to structTy*
821       auto castPtr =
822           rewriter
823               .create<LLVM::BitcastOp>(
824                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
825               .getResult();
826       // struct = LoadOp castPtr
827       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
828       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
829     } else {
830       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
831     }
832   }
833 };
834 
835 /// Pattern to lower a `memref.copy` to llvm.
836 ///
837 /// For memrefs with identity layouts, the copy is lowered to the llvm
838 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
839 /// to the generic `MemrefCopyFn`.
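/// For identity-layout (contiguous) copies, the generated code is roughly
/// (illustrative):
///
///   %bytes = <number of elements> * <element size in bytes>
///   "llvm.intr.memcpy"(%dstPtr, %srcPtr, %bytes, %false)
///
/// Non-contiguous copies instead promote both operands to unranked descriptors
/// on the stack and call the generic copy runtime function.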
840 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
841   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
842 
843   LogicalResult
844   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
845                           ConversionPatternRewriter &rewriter) const {
846     auto loc = op.getLoc();
847     auto srcType = op.source().getType().dyn_cast<MemRefType>();
848 
849     MemRefDescriptor srcDesc(adaptor.source());
850 
851     // Compute number of elements.
852     Value numElements = rewriter.create<LLVM::ConstantOp>(
853         loc, getIndexType(), rewriter.getIndexAttr(1));
854     for (int pos = 0; pos < srcType.getRank(); ++pos) {
855       auto size = srcDesc.size(rewriter, loc, pos);
856       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
857     }
858 
859     // Get element size.
860     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
861     // Compute total.
862     Value totalSize =
863         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
864 
865     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
866     Value srcOffset = srcDesc.offset(rewriter, loc);
867     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
868                                                 srcBasePtr, srcOffset);
869     MemRefDescriptor targetDesc(adaptor.target());
870     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
871     Value targetOffset = targetDesc.offset(rewriter, loc);
872     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
873                                                    targetBasePtr, targetOffset);
874     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
875         loc, typeConverter->convertType(rewriter.getI1Type()),
876         rewriter.getBoolAttr(false));
877     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
878                                     isVolatile);
879     rewriter.eraseOp(op);
880 
881     return success();
882   }
883 
884   LogicalResult
885   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
886                              ConversionPatternRewriter &rewriter) const {
887     auto loc = op.getLoc();
888     auto srcType = op.source().getType().cast<BaseMemRefType>();
889     auto targetType = op.target().getType().cast<BaseMemRefType>();
890 
891     // First make sure we have an unranked memref descriptor representation.
892     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
893       auto rank = rewriter.create<LLVM::ConstantOp>(
894           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
895       auto *typeConverter = getTypeConverter();
896       auto ptr =
897           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
898       auto voidPtr =
899           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
900               .getResult();
901       auto unrankedType =
902           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
903       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
904                                             unrankedType,
905                                             ValueRange{rank, voidPtr});
906     };
907 
908     Value unrankedSource = srcType.hasRank()
909                                ? makeUnranked(adaptor.source(), srcType)
910                                : adaptor.source();
911     Value unrankedTarget = targetType.hasRank()
912                                ? makeUnranked(adaptor.target(), targetType)
913                                : adaptor.target();
914 
915     // Now promote the unranked descriptors to the stack.
916     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
917                                                  rewriter.getIndexAttr(1));
918     auto promote = [&](Value desc) {
919       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
920       auto allocated =
921           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
922       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
923       return allocated;
924     };
925 
926     auto sourcePtr = promote(unrankedSource);
927     auto targetPtr = promote(unrankedTarget);
928 
929     unsigned typeSize =
930         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
931     auto elemSize = rewriter.create<LLVM::ConstantOp>(
932         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
933     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
934         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
935     rewriter.create<LLVM::CallOp>(loc, copyFn,
936                                   ValueRange{elemSize, sourcePtr, targetPtr});
937     rewriter.eraseOp(op);
938 
939     return success();
940   }
941 
942   LogicalResult
943   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
944                   ConversionPatternRewriter &rewriter) const override {
945     auto srcType = op.source().getType().cast<BaseMemRefType>();
946     auto targetType = op.target().getType().cast<BaseMemRefType>();
947 
948     auto isContiguousMemrefType = [](BaseMemRefType type) {
949       auto memrefType = type.dyn_cast<mlir::MemRefType>();
950       // We can use memcpy for memrefs if they have an identity layout or are
951       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
952       // special case handled by memrefCopy.
953       return memrefType &&
954              (memrefType.getLayout().isIdentity() ||
955               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
956                isStaticShapeAndContiguousRowMajor(memrefType)));
957     };
958 
959     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
960       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
961 
962     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
963   }
964 };
965 
/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref operand. In the unranked case, the fields are extracted
/// from the underlying ranked descriptor.
969 static void extractPointersAndOffset(Location loc,
970                                      ConversionPatternRewriter &rewriter,
971                                      LLVMTypeConverter &typeConverter,
972                                      Value originalOperand,
973                                      Value convertedOperand,
974                                      Value *allocatedPtr, Value *alignedPtr,
975                                      Value *offset = nullptr) {
976   Type operandType = originalOperand.getType();
977   if (operandType.isa<MemRefType>()) {
978     MemRefDescriptor desc(convertedOperand);
979     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
980     *alignedPtr = desc.alignedPtr(rewriter, loc);
981     if (offset != nullptr)
982       *offset = desc.offset(rewriter, loc);
983     return;
984   }
985 
986   unsigned memorySpace =
987       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
988   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
989   Type llvmElementType = typeConverter.convertType(elementType);
990   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
991       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
992 
993   // Extract pointer to the underlying ranked memref descriptor and cast it to
994   // ElemType**.
995   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
996   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
997 
998   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
999       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1000   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1001       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1002   if (offset != nullptr) {
1003     *offset = UnrankedMemRefDescriptor::offset(
1004         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1005   }
1006 }
1007 
1008 struct MemRefReinterpretCastOpLowering
1009     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1010   using ConvertOpToLLVMPattern<
1011       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1012 
1013   LogicalResult
1014   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1015                   ConversionPatternRewriter &rewriter) const override {
1016     Type srcType = castOp.source().getType();
1017 
1018     Value descriptor;
1019     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1020                                                adaptor, &descriptor)))
1021       return failure();
1022     rewriter.replaceOp(castOp, {descriptor});
1023     return success();
1024   }
1025 
1026 private:
1027   LogicalResult convertSourceMemRefToDescriptor(
1028       ConversionPatternRewriter &rewriter, Type srcType,
1029       memref::ReinterpretCastOp castOp,
1030       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1031     MemRefType targetMemRefType =
1032         castOp.getResult().getType().cast<MemRefType>();
1033     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1034                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1035     if (!llvmTargetDescriptorTy)
1036       return failure();
1037 
1038     // Create descriptor.
1039     Location loc = castOp.getLoc();
1040     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1041 
1042     // Set allocated and aligned pointers.
1043     Value allocatedPtr, alignedPtr;
1044     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1045                              castOp.source(), adaptor.source(), &allocatedPtr,
1046                              &alignedPtr);
1047     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1048     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1049 
1050     // Set offset.
1051     if (castOp.isDynamicOffset(0))
1052       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1053     else
1054       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1055 
1056     // Set sizes and strides.
1057     unsigned dynSizeId = 0;
1058     unsigned dynStrideId = 0;
1059     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1060       if (castOp.isDynamicSize(i))
1061         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1062       else
1063         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1064 
1065       if (castOp.isDynamicStride(i))
1066         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1067       else
1068         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1069     }
1070     *descriptor = desc;
1071     return success();
1072   }
1073 };
1074 
1075 struct MemRefReshapeOpLowering
1076     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1077   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1078 
1079   LogicalResult
1080   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1081                   ConversionPatternRewriter &rewriter) const override {
1082     Type srcType = reshapeOp.source().getType();
1083 
1084     Value descriptor;
1085     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1086                                                adaptor, &descriptor)))
1087       return failure();
1088     rewriter.replaceOp(reshapeOp, {descriptor});
1089     return success();
1090   }
1091 
1092 private:
1093   LogicalResult
1094   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1095                                   Type srcType, memref::ReshapeOp reshapeOp,
1096                                   memref::ReshapeOp::Adaptor adaptor,
1097                                   Value *descriptor) const {
1098     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1099     if (shapeMemRefType.hasStaticShape()) {
1100       MemRefType targetMemRefType =
1101           reshapeOp.getResult().getType().cast<MemRefType>();
1102       auto llvmTargetDescriptorTy =
1103           typeConverter->convertType(targetMemRefType)
1104               .dyn_cast_or_null<LLVM::LLVMStructType>();
1105       if (!llvmTargetDescriptorTy)
1106         return failure();
1107 
1108       // Create descriptor.
1109       Location loc = reshapeOp.getLoc();
1110       auto desc =
1111           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1112 
1113       // Set allocated and aligned pointers.
1114       Value allocatedPtr, alignedPtr;
1115       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1116                                reshapeOp.source(), adaptor.source(),
1117                                &allocatedPtr, &alignedPtr);
1118       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1119       desc.setAlignedPtr(rewriter, loc, alignedPtr);
1120 
1121       // Extract the offset and strides from the type.
1122       int64_t offset;
1123       SmallVector<int64_t> strides;
1124       if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1125         return rewriter.notifyMatchFailure(
1126             reshapeOp, "failed to get stride and offset exprs");
1127 
1128       if (!isStaticStrideOrOffset(offset))
1129         return rewriter.notifyMatchFailure(reshapeOp,
1130                                            "dynamic offset is unsupported");
1131       if (!llvm::all_of(strides, isStaticStrideOrOffset))
1132         return rewriter.notifyMatchFailure(reshapeOp,
1133                                            "dynamic strides are unsupported");
1134 
1135       desc.setConstantOffset(rewriter, loc, offset);
1136       for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1137         desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
1138         desc.setConstantStride(rewriter, loc, i, strides[i]);
1139       }
1140 
1141       *descriptor = desc;
1142       return success();
1143     }
1144 
    // The shape operand is a rank-1 memref with unknown (dynamic) length.
1146     Location loc = reshapeOp.getLoc();
1147     MemRefDescriptor shapeDesc(adaptor.shape());
1148     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1149 
1150     // Extract address space and element type.
1151     auto targetType =
1152         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1153     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1154     Type elementType = targetType.getElementType();
1155 
1156     // Create the unranked memref descriptor that holds the ranked one. The
1157     // inner descriptor is allocated on stack.
1158     auto targetDesc = UnrankedMemRefDescriptor::undef(
1159         rewriter, loc, typeConverter->convertType(targetType));
1160     targetDesc.setRank(rewriter, loc, resultRank);
1161     SmallVector<Value, 4> sizes;
1162     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1163                                            targetDesc, sizes);
1164     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1165         loc, getVoidPtrType(), sizes.front(), llvm::None);
1166     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1167 
1168     // Extract pointers and offset from the source memref.
1169     Value allocatedPtr, alignedPtr, offset;
1170     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1171                              reshapeOp.source(), adaptor.source(),
1172                              &allocatedPtr, &alignedPtr, &offset);
1173 
1174     // Set pointers and offset.
1175     Type llvmElementType = typeConverter->convertType(elementType);
1176     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1177         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1178     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1179                                               elementPtrPtrType, allocatedPtr);
1180     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1181                                             underlyingDescPtr,
1182                                             elementPtrPtrType, alignedPtr);
1183     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1184                                         underlyingDescPtr, elementPtrPtrType,
1185                                         offset);
1186 
    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank - 1 down
    // to 0.
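    // Sketch of the control flow generated below (illustrative pseudo code):
    //
    //   init:               br cond(%rank - 1, /*stride=*/1)
    //   cond(%i, %stride):  %p = %i >= 0; cond_br %p, body, remainder
    //   body:               sizes[%i] = shape[%i]; strides[%i] = %stride;
    //                       br cond(%i - 1, %stride * shape[%i])
    //   remainder:          <rest of the lowering>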
1189     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1190         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1191         elementPtrPtrType);
1192     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1193         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1194     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1195     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1196     Value resultRankMinusOne =
1197         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1198 
1199     Block *initBlock = rewriter.getInsertionBlock();
1200     Type indexType = getTypeConverter()->getIndexType();
1201     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1202 
1203     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1204                                             {indexType, indexType}, {loc, loc});
1205 
1206     // Move the remaining initBlock ops to condBlock.
1207     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1208     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1209 
1210     rewriter.setInsertionPointToEnd(initBlock);
1211     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1212                                 condBlock);
1213     rewriter.setInsertionPointToStart(condBlock);
1214     Value indexArg = condBlock->getArgument(0);
1215     Value strideArg = condBlock->getArgument(1);
1216 
1217     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1218     Value pred = rewriter.create<LLVM::ICmpOp>(
1219         loc, IntegerType::get(rewriter.getContext(), 1),
1220         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1221 
1222     Block *bodyBlock =
1223         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1224     rewriter.setInsertionPointToStart(bodyBlock);
1225 
1226     // Copy size from shape to descriptor.
1227     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1228     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1229         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1230     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1231     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1232                                       targetSizesBase, indexArg, size);
1233 
1234     // Write stride value and compute next one.
1235     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1236                                         targetStridesBase, indexArg, strideArg);
1237     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1238 
1239     // Decrement loop counter and branch back.
1240     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1241     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1242                                 condBlock);
1243 
1244     Block *remainder =
1245         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1246 
1247     // Hook up the cond exit to the remainder.
1248     rewriter.setInsertionPointToEnd(condBlock);
1249     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1250                                     llvm::None);
1251 
1252     // Reset position to beginning of new remainder block.
1253     rewriter.setInsertionPointToStart(remainder);
1254 
1255     *descriptor = targetDesc;
1256     return success();
1257   }
1258 };
1259 
1260 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1261 /// `Value`s.
1262 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1263                                       Type &llvmIndexType,
1264                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1265   return llvm::to_vector<4>(
1266       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1267         if (auto attr = value.dyn_cast<Attribute>())
1268           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1269         return value.get<Value>();
1270       }));
1271 }
1272 
/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
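/// For example, the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.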
1276 static DenseMap<int64_t, int64_t>
1277 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1278   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1279   for (auto &en : enumerate(reassociation)) {
1280     for (auto dim : en.value())
1281       expandedDimToCollapsedDim[dim] = en.index();
1282   }
1283   return expandedDimToCollapsedDim;
1284 }
1285 
1286 static OpFoldResult
1287 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1288                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1289                          MemRefDescriptor &inDesc,
1290                          ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
1292                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1293   int64_t outDimSize = outStaticShape[outDimIndex];
1294   if (!ShapedType::isDynamic(outDimSize))
1295     return b.getIndexAttr(outDimSize);
1296 
  // Compute the product of the sizes of all the other output dimensions in
  // this reassociation group; they must all be static.
1299   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1300   int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
1303       continue;
1304     int64_t otherDimSize = outStaticShape[otherDimIndex];
1305     assert(!ShapedType::isDynamic(otherDimSize) &&
1306            "single dimension cannot be expanded into multiple dynamic "
1307            "dimensions");
1308     otherDimSizesMul *= otherDimSize;
1309   }
1310 
  // outDimSize = inDimSize / otherDimSizesMul
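  // E.g., expanding a dynamic source dimension of runtime size %n into ?x4
  // yields %n / 4 for the dynamic output dimension.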
1312   int64_t inDimSize = inStaticShape[inDimIndex];
1313   Value inDimSizeDynamic =
1314       ShapedType::isDynamic(inDimSize)
1315           ? inDesc.size(b, loc, inDimIndex)
1316           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1317                                        b.getIndexAttr(inDimSize));
1318   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1319       loc, inDimSizeDynamic,
1320       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1321                                  b.getIndexAttr(otherDimSizesMul)));
1322   return outDimSizeDynamic;
1323 }
1324 
1325 static OpFoldResult getCollapsedOutputDimSize(
1326     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1327     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
1329   if (!ShapedType::isDynamic(outDimSize))
1330     return b.getIndexAttr(outDimSize);
1331 
1332   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1333   Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
1335     int64_t inDimSize = inStaticShape[inDimIndex];
1336     Value inDimSizeDynamic =
1337         ShapedType::isDynamic(inDimSize)
1338             ? inDesc.size(b, loc, inDimIndex)
1339             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1340                                          b.getIndexAttr(inDimSize));
1341     outDimSizeDynamic =
1342         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1343   }
1344   return outDimSizeDynamic;
1345 }
1346 
1347 static SmallVector<OpFoldResult, 4>
1348 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1349                         ArrayRef<ReassociationIndices> reassociation,
1350                         ArrayRef<int64_t> inStaticShape,
1351                         MemRefDescriptor &inDesc,
1352                         ArrayRef<int64_t> outStaticShape) {
1353   return llvm::to_vector<4>(llvm::map_range(
1354       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1355         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1356                                          outStaticShape[outDimIndex],
1357                                          inStaticShape, inDesc, reassociation);
1358       }));
1359 }
1360 
1361 static SmallVector<OpFoldResult, 4>
1362 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1363                        ArrayRef<ReassociationIndices> reassociation,
1364                        ArrayRef<int64_t> inStaticShape,
1365                        MemRefDescriptor &inDesc,
1366                        ArrayRef<int64_t> outStaticShape) {
1367   DenseMap<int64_t, int64_t> outDimToInDimMap =
1368       getExpandedDimToCollapsedDimMap(reassociation);
1369   return llvm::to_vector<4>(llvm::map_range(
1370       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1371         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1372                                         outStaticShape, inDesc, inStaticShape,
1373                                         reassociation, outDimToInDimMap);
1374       }));
1375 }
1376 
1377 static SmallVector<Value>
1378 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1379                       ArrayRef<ReassociationIndices> reassociation,
1380                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1381                       ArrayRef<int64_t> outStaticShape) {
1382   return outStaticShape.size() < inStaticShape.size()
1383              ? getAsValues(b, loc, llvmIndexType,
1384                            getCollapsedOutputShape(b, loc, llvmIndexType,
1385                                                    reassociation, inStaticShape,
1386                                                    inDesc, outStaticShape))
1387              : getAsValues(b, loc, llvmIndexType,
1388                            getExpandedOutputShape(b, loc, llvmIndexType,
1389                                                   reassociation, inStaticShape,
1390                                                   inDesc, outStaticShape));
1391 }
1392 
1393 static void fillInStridesForExpandedMemDescriptor(
1394     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1395     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1396   // See comments for computeExpandedLayoutMap for details on how the strides
1397   // are calculated.
1398   for (auto &en : llvm::enumerate(reassociation)) {
1399     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1400     for (auto dstIndex : llvm::reverse(en.value())) {
1401       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1402       Value size = dstDesc.size(b, loc, dstIndex);
1403       currentStrideToExpand =
1404           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1405     }
1406   }
1407 }
1408 
1409 static void fillInStridesForCollapsedMemDescriptor(
1410     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1411     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1412     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1413   // See comments for computeCollapsedLayoutMap for details on how the strides
1414   // are calculated.
1415   auto srcShape = srcType.getShape();
1416   for (auto &en : llvm::enumerate(reassociation)) {
1417     rewriter.setInsertionPoint(op);
1418     auto dstIndex = en.index();
1419     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1420     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1421       ref = ref.drop_back();
1422     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1423       dstDesc.setStride(rewriter, loc, dstIndex,
1424                         srcDesc.stride(rewriter, loc, ref.back()));
1425     } else {
1426       // Iterate over the source strides in reverse order. Skip over the
1427       // dimensions whose size is 1.
1428       // TODO: we should take the minimum stride in the reassociation group
1429       // instead of just the first where the dimension is not 1.
1430       //
1431       // +------------------------------------------------------+
1432       // | curEntry:                                            |
1433       // |   %srcStride = strides[srcIndex]                     |
1434       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1435       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1436       // +-------------------------+----------------------------+  |
1437       //                           |                               |
1438       //                           v                               |
1439       //            +-----------------------------+                |
1440       //            | nextEntry:                  |                |
1441       //            |   ...                       +---+            |
1442       //            +--------------+--------------+   |            |
1443       //                           |                  |            |
1444       //                           v                  |            |
1445       //            +-----------------------------+   |            |
1446       //            | nextEntry:                  |   |            |
1447       //            |   ...                       |   |            |
1448       //            +--------------+--------------+   |   +--------+
1449       //                           |                  |   |
1450       //                           v                  v   v
1451       //   +--------------------------------------------------+
1452       //   | continue(%newStride):                            |
1453       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1454       //   +--------------------------------------------------+
1455       OpBuilder::InsertionGuard guard(rewriter);
1456       Block *initBlock = rewriter.getInsertionBlock();
1457       Block *continueBlock =
1458           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1459       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1460       rewriter.setInsertionPointToStart(continueBlock);
1461       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1462 
1463       Block *curEntryBlock = initBlock;
1464       Block *nextEntryBlock;
1465       for (auto srcIndex : llvm::reverse(ref)) {
1466         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1467           continue;
1468         rewriter.setInsertionPointToEnd(curEntryBlock);
1469         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1470         if (srcIndex == ref.front()) {
1471           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1472           break;
1473         }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, srcDesc.getIndexType(), rewriter.getIndexAttr(1));
1477         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1478             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1479             one);
1480         {
1481           OpBuilder::InsertionGuard guard(rewriter);
1482           nextEntryBlock = rewriter.createBlock(
1483               initBlock->getParent(), Region::iterator(continueBlock), {});
1484         }
1485         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1486                                         srcStride, nextEntryBlock, llvm::None);
1487         curEntryBlock = nextEntryBlock;
1488       }
1489     }
1490   }
1491 }
1492 
1493 static void fillInDynamicStridesForMemDescriptor(
1494     ConversionPatternRewriter &b, Location loc, Operation *op,
1495     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1496     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1497     ArrayRef<ReassociationIndices> reassociation) {
1498   if (srcType.getRank() > dstType.getRank())
1499     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1500                                            srcDesc, dstDesc, reassociation);
1501   else
1502     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1503                                           reassociation);
1504 }
1505 
// ReshapeOp creates a new view descriptor of the proper rank.
// The result type must have a strided layout; both static and dynamic sizes
// and strides are supported.
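// As a rough sketch for the collapsing direction:
//
//   %1 = memref.collapse_shape %0 [[0, 1], [2]]
//       : memref<4x4x?xf32> into memref<16x?xf32>
//
// The result descriptor reuses the source pointers and offset. Each dynamic
// result size is the product (when collapsing) or quotient (when expanding)
// of the corresponding source sizes, and the strides are either copied over,
// recomputed for identity layouts, or derived per reassociation group by the
// helpers above.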
1509 template <typename ReshapeOp>
1510 class ReassociatingReshapeOpConversion
1511     : public ConvertOpToLLVMPattern<ReshapeOp> {
1512 public:
1513   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1514   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1515 
1516   LogicalResult
1517   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1518                   ConversionPatternRewriter &rewriter) const override {
1519     MemRefType dstType = reshapeOp.getResultType();
1520     MemRefType srcType = reshapeOp.getSrcType();
1521 
1522     int64_t offset;
1523     SmallVector<int64_t, 4> strides;
1524     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1525       return rewriter.notifyMatchFailure(
1526           reshapeOp, "failed to get stride and offset exprs");
1527     }
1528 
1529     MemRefDescriptor srcDesc(adaptor.src());
1530     Location loc = reshapeOp->getLoc();
1531     auto dstDesc = MemRefDescriptor::undef(
1532         rewriter, loc, this->typeConverter->convertType(dstType));
1533     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1534     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1535     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1536 
1537     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1538     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1539     Type llvmIndexType =
1540         this->typeConverter->convertType(rewriter.getIndexType());
1541     SmallVector<Value> dstShape = getDynamicOutputShape(
1542         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1543         srcStaticShape, srcDesc, dstStaticShape);
1544     for (auto &en : llvm::enumerate(dstShape))
1545       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1546 
1547     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1548       for (auto &en : llvm::enumerate(strides))
1549         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1550     } else if (srcType.getLayout().isIdentity() &&
1551                dstType.getLayout().isIdentity()) {
1552       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1553                                                    rewriter.getIndexAttr(1));
1554       Value stride = c1;
1555       for (auto dimIndex :
1556            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1557         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1558         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1559       }
1560     } else {
1561       // There could be mixed static/dynamic strides. For simplicity, we
1562       // recompute all strides if there is at least one dynamic stride.
1563       fillInDynamicStridesForMemDescriptor(
1564           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1565           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1566     }
1567     rewriter.replaceOp(reshapeOp, {dstDesc});
1568     return success();
1569   }
1570 };
1571 
1572 /// Conversion pattern that transforms a subview op into:
1573 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1574 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1575 ///      and stride.
1576 /// The subview op is replaced by the descriptor.
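/// As a rough sketch (the exact result layout attribute is elided):
///
///   %1 = memref.subview %0[%i, %j][4, 4][1, 1]
///       : memref<8x16xf32> to memref<4x4xf32, #layout>
///
/// The new descriptor reuses the source pointers (bitcast to the result
/// element type), sets the offset to srcOffset + %i * stride0 + %j * stride1,
/// and fills the result sizes and strides from the static/dynamic operands.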
1577 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1578   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1579 
1580   LogicalResult
1581   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1582                   ConversionPatternRewriter &rewriter) const override {
1583     auto loc = subViewOp.getLoc();
1584 
1585     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1586     auto sourceElementTy =
1587         typeConverter->convertType(sourceMemRefType.getElementType());
1588 
1589     auto viewMemRefType = subViewOp.getType();
1590     auto inferredType = memref::SubViewOp::inferResultType(
1591                             subViewOp.getSourceType(),
1592                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1593                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1594                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1595                             .cast<MemRefType>();
1596     auto targetElementTy =
1597         typeConverter->convertType(viewMemRefType.getElementType());
1598     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1599     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1600         !LLVM::isCompatibleType(sourceElementTy) ||
1601         !LLVM::isCompatibleType(targetElementTy) ||
1602         !LLVM::isCompatibleType(targetDescTy))
1603       return failure();
1604 
1605     // Extract the offset and strides from the type.
1606     int64_t offset;
1607     SmallVector<int64_t, 4> strides;
1608     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1609     if (failed(successStrides))
1610       return failure();
1611 
1612     // Create the descriptor.
1613     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1614       return failure();
1615     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1616     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1617 
1618     // Copy the buffer pointer from the old descriptor to the new one.
1619     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1620     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1621         loc,
1622         LLVM::LLVMPointerType::get(targetElementTy,
1623                                    viewMemRefType.getMemorySpaceAsInt()),
1624         extracted);
1625     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1626 
1627     // Copy the aligned pointer from the old descriptor to the new one.
1628     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1629     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1630         loc,
1631         LLVM::LLVMPointerType::get(targetElementTy,
1632                                    viewMemRefType.getMemorySpaceAsInt()),
1633         extracted);
1634     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1635 
1636     size_t inferredShapeRank = inferredType.getRank();
1637     size_t resultShapeRank = viewMemRefType.getRank();
1638 
1639     // Extract strides needed to compute offset.
1640     SmallVector<Value, 4> strideValues;
1641     strideValues.reserve(inferredShapeRank);
1642     for (unsigned i = 0; i < inferredShapeRank; ++i)
1643       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1644 
1645     // Offset.
1646     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1647     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1648       targetMemRef.setConstantOffset(rewriter, loc, offset);
1649     } else {
1650       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1651       // `inferredShapeRank` may be larger than the number of offset operands
1652       // because of trailing semantics. In this case, the offset is guaranteed
1653       // to be interpreted as 0 and we can just skip the extra dimensions.
1654       for (unsigned i = 0, e = std::min(inferredShapeRank,
1655                                         subViewOp.getMixedOffsets().size());
1656            i < e; ++i) {
1657         Value offset =
1658             // TODO: need OpFoldResult ODS adaptor to clean this up.
1659             subViewOp.isDynamicOffset(i)
1660                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1661                 : rewriter.create<LLVM::ConstantOp>(
1662                       loc, llvmIndexType,
1663                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1664         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1665         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1666       }
1667       targetMemRef.setOffset(rewriter, loc, baseOffset);
1668     }
1669 
1670     // Update sizes and strides.
1671     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1672     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1673     assert(mixedSizes.size() == mixedStrides.size() &&
1674            "expected sizes and strides of equal length");
1675     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1676     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1677          i >= 0 && j >= 0; --i) {
1678       if (unusedDims.test(i))
1679         continue;
1680 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full extent of the source dimension and the stride to be 1.
1684       Value size, stride;
1685       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1686         // If the static size is available, use it directly. This is similar to
1687         // the folding of dim(constant-op) but removes the need for dim to be
1688         // aware of LLVM constants and for this pass to be aware of std
1689         // constants.
1690         int64_t staticSize =
1691             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1692         if (staticSize != ShapedType::kDynamicSize) {
1693           size = rewriter.create<LLVM::ConstantOp>(
1694               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1695         } else {
1696           Value pos = rewriter.create<LLVM::ConstantOp>(
1697               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1698           Value dim =
1699               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1700           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1701               loc, llvmIndexType, dim);
1702           size = cast.getResult(0);
1703         }
1704         stride = rewriter.create<LLVM::ConstantOp>(
1705             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1706       } else {
1707         // TODO: need OpFoldResult ODS adaptor to clean this up.
1708         size =
1709             subViewOp.isDynamicSize(i)
1710                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1711                 : rewriter.create<LLVM::ConstantOp>(
1712                       loc, llvmIndexType,
1713                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1714         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1715           stride = rewriter.create<LLVM::ConstantOp>(
1716               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1717         } else {
1718           stride =
1719               subViewOp.isDynamicStride(i)
1720                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1721                   : rewriter.create<LLVM::ConstantOp>(
1722                         loc, llvmIndexType,
1723                         rewriter.getI64IntegerAttr(
1724                             subViewOp.getStaticStride(i)));
1725           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1726         }
1727       }
1728       targetMemRef.setSize(rewriter, loc, j, size);
1729       targetMemRef.setStride(rewriter, loc, j, stride);
1730       j--;
1731     }
1732 
1733     rewriter.replaceOp(subViewOp, {targetMemRef});
1734     return success();
1735   }
1736 };
1737 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
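/// As a rough sketch (layout attributes elided):
///
///   %1 = memref.transpose %0 (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, #layout>
///
/// The result descriptor copies the pointers and offset from the source and
/// permutes the per-dimension sizes and strides according to the permutation
/// map.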
1745 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1746 public:
1747   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1748 
1749   LogicalResult
1750   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1751                   ConversionPatternRewriter &rewriter) const override {
1752     auto loc = transposeOp.getLoc();
1753     MemRefDescriptor viewMemRef(adaptor.in());
1754 
1755     // No permutation, early exit.
1756     if (transposeOp.permutation().isIdentity())
1757       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1758 
1759     auto targetMemRef = MemRefDescriptor::undef(
1760         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1761 
1762     // Copy the base and aligned pointers from the old descriptor to the new
1763     // one.
1764     targetMemRef.setAllocatedPtr(rewriter, loc,
1765                                  viewMemRef.allocatedPtr(rewriter, loc));
1766     targetMemRef.setAlignedPtr(rewriter, loc,
1767                                viewMemRef.alignedPtr(rewriter, loc));
1768 
1769     // Copy the offset pointer from the old descriptor to the new one.
1770     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1771 
1772     // Iterate over the dimensions and apply size/stride permutation.
1773     for (const auto &en :
1774          llvm::enumerate(transposeOp.permutation().getResults())) {
1775       int sourcePos = en.index();
1776       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1777       targetMemRef.setSize(rewriter, loc, targetPos,
1778                            viewMemRef.size(rewriter, loc, sourcePos));
1779       targetMemRef.setStride(rewriter, loc, targetPos,
1780                              viewMemRef.stride(rewriter, loc, sourcePos));
1781     }
1782 
1783     rewriter.replaceOp(transposeOp, {targetMemRef});
1784     return success();
1785   }
1786 };
1787 
/// Conversion pattern that transforms a view op into:
1789 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1790 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1791 ///      and stride.
1792 /// The view op is replaced by the descriptor.
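/// As a rough sketch:
///
///   %1 = memref.view %0[%shift][%size0]
///       : memref<2048xi8> to memref<?x16xf32>
///
/// The aligned pointer is advanced by %shift bytes via a GEP on the i8 buffer
/// and bitcast to the f32 element type, the offset field is set to 0, and the
/// sizes and strides are filled innermost-to-outermost with a running stride.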
1793 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1794   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1795 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand.
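  // E.g., for shape [4, ?, 8, ?] and idx = 3, the result is dynamicSizes[1],
  // since one dynamic dimension precedes index 3.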
1798   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1799                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1800                 unsigned idx) const {
1801     assert(idx < shape.size());
1802     if (!ShapedType::isDynamic(shape[idx]))
1803       return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in the range [0, idx).
1805     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1806       return ShapedType::isDynamic(v);
1807     });
1808     return dynamicSizes[nDynamic];
1809   }
1810 
1811   // Build and return the idx^th stride, either by returning the constant stride
1812   // or by computing the dynamic stride from the current `runningStride` and
1813   // `nextSize`. The caller should keep a running stride and update it with the
1814   // result returned by this function.
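  // When the strides are dynamic this amounts to
  // stride[idx] = stride[idx + 1] * size[idx + 1], walking from the innermost
  // dimension outwards, with the innermost stride being 1.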
1815   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1816                   ArrayRef<int64_t> strides, Value nextSize,
1817                   Value runningStride, unsigned idx) const {
1818     assert(idx < strides.size());
1819     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1820       return createIndexConstant(rewriter, loc, strides[idx]);
1821     if (nextSize)
1822       return runningStride
1823                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1824                  : nextSize;
1825     assert(!runningStride);
1826     return createIndexConstant(rewriter, loc, 1);
1827   }
1828 
1829   LogicalResult
1830   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1831                   ConversionPatternRewriter &rewriter) const override {
1832     auto loc = viewOp.getLoc();
1833 
1834     auto viewMemRefType = viewOp.getType();
1835     auto targetElementTy =
1836         typeConverter->convertType(viewMemRefType.getElementType());
1837     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1838     if (!targetDescTy || !targetElementTy ||
1839         !LLVM::isCompatibleType(targetElementTy) ||
1840         !LLVM::isCompatibleType(targetDescTy))
1841       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1842              failure();
1843 
1844     int64_t offset;
1845     SmallVector<int64_t, 4> strides;
1846     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1847     if (failed(successStrides))
1848       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1849     assert(offset == 0 && "expected offset to be 0");
1850 
1851     // Target memref must be contiguous in memory (innermost stride is 1), or
1852     // empty (special case when at least one of the memref dimensions is 0).
1853     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1854       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1855              failure();
1856 
1857     // Create the descriptor.
1858     MemRefDescriptor sourceMemRef(adaptor.source());
1859     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1860 
1861     // Field 1: Copy the allocated pointer, used for malloc/free.
1862     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1863     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1864     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1865         loc,
1866         LLVM::LLVMPointerType::get(targetElementTy,
1867                                    srcMemRefType.getMemorySpaceAsInt()),
1868         allocatedPtr);
1869     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1870 
1871     // Field 2: Copy the actual aligned pointer to payload.
1872     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1873     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1874                                               alignedPtr, adaptor.byte_shift());
1875     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1876         loc,
1877         LLVM::LLVMPointerType::get(targetElementTy,
1878                                    srcMemRefType.getMemorySpaceAsInt()),
1879         alignedPtr);
1880     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1881 
1882     // Field 3: The offset in the resulting type must be 0. This is because of
1883     // the type change: an offset on srcType* may not be expressible as an
1884     // offset on dstType*.
1885     targetMemRef.setOffset(rewriter, loc,
1886                            createIndexConstant(rewriter, loc, offset));
1887 
1888     // Early exit for 0-D corner case.
1889     if (viewMemRefType.getRank() == 0)
1890       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1891 
1892     // Fields 4 and 5: Update sizes and strides.
1893     Value stride = nullptr, nextSize = nullptr;
1894     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1895       // Update size.
1896       Value size =
1897           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1898       targetMemRef.setSize(rewriter, loc, i, size);
1899       // Update stride.
1900       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1901       targetMemRef.setStride(rewriter, loc, i, stride);
1902       nextSize = size;
1903     }
1904 
1905     rewriter.replaceOp(viewOp, {targetMemRef});
1906     return success();
1907   }
1908 };
1909 
1910 //===----------------------------------------------------------------------===//
1911 // AtomicRMWOpLowering
1912 //===----------------------------------------------------------------------===//
1913 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
1916 static Optional<LLVM::AtomicBinOp>
1917 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1918   switch (atomicOp.kind()) {
1919   case arith::AtomicRMWKind::addf:
1920     return LLVM::AtomicBinOp::fadd;
1921   case arith::AtomicRMWKind::addi:
1922     return LLVM::AtomicBinOp::add;
1923   case arith::AtomicRMWKind::assign:
1924     return LLVM::AtomicBinOp::xchg;
1925   case arith::AtomicRMWKind::maxs:
1926     return LLVM::AtomicBinOp::max;
1927   case arith::AtomicRMWKind::maxu:
1928     return LLVM::AtomicBinOp::umax;
1929   case arith::AtomicRMWKind::mins:
1930     return LLVM::AtomicBinOp::min;
1931   case arith::AtomicRMWKind::minu:
1932     return LLVM::AtomicBinOp::umin;
1933   case arith::AtomicRMWKind::ori:
1934     return LLVM::AtomicBinOp::_or;
1935   case arith::AtomicRMWKind::andi:
1936     return LLVM::AtomicBinOp::_and;
1937   default:
1938     return llvm::None;
1939   }
1941 }
1942 
1943 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1944   using Base::Base;
1945 
1946   LogicalResult
1947   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1948                   ConversionPatternRewriter &rewriter) const override {
1949     if (failed(match(atomicOp)))
1950       return failure();
1951     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1952     if (!maybeKind)
1953       return failure();
1954     auto resultType = adaptor.value().getType();
1955     auto memRefType = atomicOp.getMemRefType();
1956     auto dataPtr =
1957         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1958                              adaptor.indices(), rewriter);
1959     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1960         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1961         LLVM::AtomicOrdering::acq_rel);
1962     return success();
1963   }
1964 };
1965 
1966 } // namespace
1967 
1968 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1969                                                   RewritePatternSet &patterns) {
1970   // clang-format off
1971   patterns.add<
1972       AllocaOpLowering,
1973       AllocaScopeOpLowering,
1974       AtomicRMWOpLowering,
1975       AssumeAlignmentOpLowering,
1976       DimOpLowering,
1977       GenericAtomicRMWOpLowering,
1978       GlobalMemrefOpLowering,
1979       GetGlobalMemrefOpLowering,
1980       LoadOpLowering,
1981       MemRefCastOpLowering,
1982       MemRefCopyOpLowering,
1983       MemRefReinterpretCastOpLowering,
1984       MemRefReshapeOpLowering,
1985       PrefetchOpLowering,
1986       RankOpLowering,
1987       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1988       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1989       StoreOpLowering,
1990       SubViewOpLowering,
1991       TransposeOpLowering,
1992       ViewOpLowering>(converter);
1993   // clang-format on
1994   auto allocLowering = converter.getOptions().allocLowering;
1995   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1996     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1997   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1998     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1999 }
2000 
2001 namespace {
2002 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2003   MemRefToLLVMPass() = default;
2004 
2005   void runOnOperation() override {
2006     Operation *op = getOperation();
2007     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2008     LowerToLLVMOptions options(&getContext(),
2009                                dataLayoutAnalysis.getAtOrAbove(op));
2010     options.allocLowering =
2011         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2012                          : LowerToLLVMOptions::AllocLowering::Malloc);
2013     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2014       options.overrideIndexBitwidth(indexBitwidth);
2015 
2016     LLVMTypeConverter typeConverter(&getContext(), options,
2017                                     &dataLayoutAnalysis);
2018     RewritePatternSet patterns(&getContext());
2019     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2020     LLVMConversionTarget target(getContext());
2021     target.addLegalOp<func::FuncOp>();
2022     if (failed(applyPartialConversion(op, target, std::move(patterns))))
2023       signalPassFailure();
2024   }
2025 };
2026 } // namespace
2027 
2028 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2029   return std::make_unique<MemRefToLLVMPass>();
2030 }
2031