1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
30   return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
31 }
32 
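/// Lowers `memref.alloc` to a call to `malloc`. When an alignment is
/// requested (or implied by a non-trivial element type), the allocation is
/// padded by the alignment and the aligned pointer is recomputed from the raw
/// malloc result. A rough sketch of the emitted IR, assuming an f32 element
/// type and a 64-bit index type (exact types depend on the data layout):
///
///   %raw   = llvm.call @malloc(%sizeBytes) : (i64) -> !llvm.ptr<i8>
///   %typed = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
///   // With an alignment: ptrtoint, round up to the alignment, inttoptr to
///   // obtain the aligned pointer stored in the memref descriptor.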
33 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
34   AllocOpLowering(LLVMTypeConverter &converter)
35       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
36                                 converter) {}
37 
38   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
39                                           Location loc, Value sizeBytes,
40                                           Operation *op) const override {
41     // Heap allocations.
42     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
43     MemRefType memRefType = allocOp.getType();
44 
45     Value alignment;
46     if (auto alignmentAttr = allocOp.getAlignment()) {
47       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
48     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
49       // In the case where no alignment is specified, we may want to override
50       // `malloc's` behavior. `malloc` typically aligns at the size of the
51       // biggest scalar on a target HW. For non-scalars, use the natural
52       // alignment of the LLVM type given by the LLVM DataLayout.
53       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
54     }
55 
56     if (alignment) {
57       // Adjust the allocation size to consider alignment.
58       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
59     }
60 
61     // Allocate the underlying buffer and store a pointer to it in the MemRef
62     // descriptor.
63     Type elementPtrType = this->getElementPtrType(memRefType);
64     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
65         allocOp->getParentOfType<ModuleOp>(), getIndexType());
66     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
67                                   getVoidPtrType());
68     Value allocatedPtr =
69         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
70 
71     Value alignedPtr = allocatedPtr;
72     if (alignment) {
73       // Compute the aligned type pointer.
74       Value allocatedInt =
75           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
76       Value alignmentInt =
77           createAligned(rewriter, loc, allocatedInt, alignment);
78       alignedPtr =
79           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
80     }
81 
82     return std::make_tuple(allocatedPtr, alignedPtr);
83   }
84 };
85 
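/// Lowers `memref.alloc` to a call to `aligned_alloc` when the lowering is
/// configured to use aligned allocations. The size is padded up to a multiple
/// of the alignment when it is not known to be one already. Roughly (a
/// sketch, assuming an f32 element type and a 64-bit index type):
///
///   %buf = llvm.call @aligned_alloc(%alignment, %paddedSize)
///            : (i64, i64) -> !llvm.ptr<i8>
///   %ptr = llvm.bitcast %buf : !llvm.ptr<i8> to !llvm.ptr<f32>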
86 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
87   AlignedAllocOpLowering(LLVMTypeConverter &converter)
88       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
89                                 converter) {}
90 
91   /// Returns the memref's element size in bytes using the data layout active at
92   /// `op`.
93   // TODO: there are other places where this is used. Expose publicly?
94   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
95     const DataLayout *layout = &defaultLayout;
96     if (const DataLayoutAnalysis *analysis =
97             getTypeConverter()->getDataLayoutAnalysis()) {
98       layout = &analysis->getAbove(op);
99     }
100     Type elementType = memRefType.getElementType();
101     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
102       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
103                                                          *layout);
104     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
105       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
106           memRefElementType, *layout);
107     return layout->getTypeSize(elementType);
108   }
109 
110   /// Returns true if the memref size in bytes is known to be a multiple of
111   /// factor assuming the data layout active at `op`.
112   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
113                               Operation *op) const {
114     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
115     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
116       if (ShapedType::isDynamic(type.getDimSize(i)))
117         continue;
118       sizeDivisor = sizeDivisor * type.getDimSize(i);
119     }
120     return sizeDivisor % factor == 0;
121   }
122 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
126   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
127     if (Optional<uint64_t> alignment = allocOp.getAlignment())
128       return *alignment;
129 
    // Whenever we don't have an alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it is not
    // one already.
133     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
134     return std::max(kMinAlignedAllocAlignment,
135                     llvm::PowerOf2Ceil(eltSizeBytes));
136   }
137 
138   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
139                                           Location loc, Value sizeBytes,
140                                           Operation *op) const override {
141     // Heap allocations.
142     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
143     MemRefType memRefType = allocOp.getType();
144     int64_t alignment = getAllocationAlignment(allocOp);
145     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
146 
147     // aligned_alloc requires size to be a multiple of alignment; we will pad
148     // the size to the next multiple if necessary.
149     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
150       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
151 
152     Type elementPtrType = this->getElementPtrType(memRefType);
153     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
154         allocOp->getParentOfType<ModuleOp>(), getIndexType());
155     auto results =
156         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
157                        getVoidPtrType());
158     Value allocatedPtr =
159         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
160 
161     return std::make_tuple(allocatedPtr, allocatedPtr);
162   }
163 
164   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
165   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
166 
167   /// Default layout to use in absence of the corresponding analysis.
168   DataLayout defaultLayout;
169 };
170 
// Out-of-line definition, required until C++17.
172 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
173 
174 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
175   AllocaOpLowering(LLVMTypeConverter &converter)
176       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
177                                 converter) {}
178 
  /// Allocates the underlying buffer using an LLVM `alloca`. Since this is a
  /// stack allocation, no post-allocation realignment is needed: the allocated
  /// and aligned pointers are one and the same.
182   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
183                                           Location loc, Value sizeBytes,
184                                           Operation *op) const override {
185 
    // With alloca, one gets a pointer to the element type right away: this is
    // a stack allocation, so no separate aligned pointer is required.
188     auto allocaOp = cast<memref::AllocaOp>(op);
189     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
190 
191     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
192         loc, elementPtrType, sizeBytes,
193         allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);
194 
195     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
196   }
197 };
198 
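/// Lowers `memref.alloca_scope` by inlining its region into the surrounding
/// CFG and bracketing it with stack save/restore intrinsics, so that allocas
/// made inside the scope are released when the scope exits. A rough sketch of
/// the resulting control flow:
///
///   ^current:
///     %sp = llvm.intr.stacksave : !llvm.ptr<i8>
///     llvm.br ^body
///   ^body:
///     ...inlined scope body...
///     llvm.intr.stackrestore %sp
///     llvm.br ^continue(<scope results>)
///   ^continue(...):
///     ...code after the alloca_scope...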
199 struct AllocaScopeOpLowering
200     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
201   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
202 
203   LogicalResult
204   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
205                   ConversionPatternRewriter &rewriter) const override {
206     OpBuilder::InsertionGuard guard(rewriter);
207     Location loc = allocaScopeOp.getLoc();
208 
209     // Split the current block before the AllocaScopeOp to create the inlining
210     // point.
211     auto *currentBlock = rewriter.getInsertionBlock();
212     auto *remainingOpsBlock =
213         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
214     Block *continueBlock;
215     if (allocaScopeOp.getNumResults() == 0) {
216       continueBlock = remainingOpsBlock;
217     } else {
218       continueBlock = rewriter.createBlock(
219           remainingOpsBlock, allocaScopeOp.getResultTypes(),
220           SmallVector<Location>(allocaScopeOp->getNumResults(),
221                                 allocaScopeOp.getLoc()));
222       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
223     }
224 
225     // Inline body region.
226     Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
227     Block *afterBody = &allocaScopeOp.getBodyRegion().back();
228     rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
229 
230     // Save stack and then branch into the body of the region.
231     rewriter.setInsertionPointToEnd(currentBlock);
232     auto stackSaveOp =
233         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
234     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
235 
    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored right before leaving the body region.
238     rewriter.setInsertionPointToEnd(afterBody);
239     auto returnOp =
240         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
241     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
242         returnOp, returnOp.getResults(), continueBlock);
243 
    // Insert stack restore before jumping out of the body of the region.
245     rewriter.setInsertionPoint(branchOp);
246     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
247 
    // Replace the op with values returned from the body region.
249     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
250 
251     return success();
252   }
253 };
254 
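/// Lowers `memref.assume_alignment` to an `llvm.intr.assume` on the aligned
/// pointer of the descriptor. Roughly (a sketch, assuming an f32 element
/// type, a 64-bit index type, and alignment 16):
///
///   %p    = llvm.ptrtoint %alignedPtr : !llvm.ptr<f32> to i64
///   %and  = llvm.and %p, %c15
///   %cond = llvm.icmp "eq" %and, %c0
///   llvm.intr.assume %cond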
255 struct AssumeAlignmentOpLowering
256     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
257   using ConvertOpToLLVMPattern<
258       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
259 
260   LogicalResult
261   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
262                   ConversionPatternRewriter &rewriter) const override {
263     Value memref = adaptor.getMemref();
264     unsigned alignment = op.getAlignment();
265     auto loc = op.getLoc();
266 
267     MemRefDescriptor memRefDescriptor(memref);
268     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
269 
270     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
271     // the asserted memref.alignedPtr isn't used anywhere else, as the real
272     // users like load/store/views always re-extract memref.alignedPtr as they
273     // get lowered.
274     //
275     // This relies on LLVM's CSE optimization (potentially after SROA), since
276     // after CSE all memref.alignedPtr instances get de-duplicated into the same
277     // pointer SSA value.
278     auto intPtrType =
279         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
280     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
281     Value mask =
282         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
283     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
284     rewriter.create<LLVM::AssumeOp>(
285         loc, rewriter.create<LLVM::ICmpOp>(
286                  loc, LLVM::ICmpPredicate::eq,
287                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
288 
289     rewriter.eraseOp(op);
290     return success();
291   }
292 };
293 
// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. Since the memref descriptor is an SSA value, there is no need to
// clean it up in any way.
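// For example (a rough sketch; the struct and pointer types depend on the
// memref being freed):
//   %alloc = llvm.extractvalue %desc[0] : !llvm.struct<(ptr<f32>, ...)>
//   %void  = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()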
297 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
298   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
299 
300   explicit DeallocOpLowering(LLVMTypeConverter &converter)
301       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
302 
303   LogicalResult
304   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
305                   ConversionPatternRewriter &rewriter) const override {
306     // Insert the `free` declaration if it is not already present.
307     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
308     MemRefDescriptor memref(adaptor.getMemref());
309     Value casted = rewriter.create<LLVM::BitcastOp>(
310         op.getLoc(), getVoidPtrType(),
311         memref.allocatedPtr(rewriter, op.getLoc()));
312     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
313         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
314     return success();
315   }
316 };
317 
318 struct AlignedDeallocOpLowering
319     : public ConvertOpToLLVMPattern<memref::DeallocOp> {
320   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
321 
322   explicit AlignedDeallocOpLowering(LLVMTypeConverter &converter)
323       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
324 
325   LogicalResult
326   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
327                   ConversionPatternRewriter &rewriter) const override {
328     // Insert the `free` declaration if it is not already present.
329     auto freeFunc =
330         LLVM::lookupOrCreateAlignedFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
332     Value casted = rewriter.create<LLVM::BitcastOp>(
333         op.getLoc(), getVoidPtrType(),
334         memref.allocatedPtr(rewriter, op.getLoc()));
335     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
336         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
337     return success();
338   }
339 };
340 
341 // A `dim` is converted to a constant for static sizes and to an access to the
342 // size stored in the memref descriptor for dynamic sizes.
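// For example, with `%m : memref<4x?xf32>`, `memref.dim %m, %c1` becomes,
// roughly, an extractvalue of the dynamic size from the descriptor, while
// `memref.dim %m, %c0` lowers to the constant 4 (a sketch; the exact
// descriptor field indices depend on the lowering).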
343 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
344   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
345 
346   LogicalResult
347   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
348                   ConversionPatternRewriter &rewriter) const override {
349     Type operandType = dimOp.getSource().getType();
350     if (operandType.isa<UnrankedMemRefType>()) {
351       rewriter.replaceOp(
352           dimOp, {extractSizeOfUnrankedMemRef(
353                      operandType, dimOp, adaptor.getOperands(), rewriter)});
354 
355       return success();
356     }
357     if (operandType.isa<MemRefType>()) {
358       rewriter.replaceOp(
359           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
360                                             adaptor.getOperands(), rewriter)});
361       return success();
362     }
363     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
364   }
365 
366 private:
367   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
368                                     OpAdaptor adaptor,
369                                     ConversionPatternRewriter &rewriter) const {
370     Location loc = dimOp.getLoc();
371 
372     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
373     auto scalarMemRefType =
374         MemRefType::get({}, unrankedMemRefType.getElementType());
375     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
376 
377     // Extract pointer to the underlying ranked descriptor and bitcast it to a
378     // memref<element_type> descriptor pointer to minimize the number of GEP
379     // operations.
380     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
381     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
382     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
383         loc,
384         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
385                                    addressSpace),
386         underlyingRankedDesc);
387 
388     // Get pointer to offset field of memref<element_type> descriptor.
389     Type indexPtrTy = LLVM::LLVMPointerType::get(
390         getTypeConverter()->getIndexType(), addressSpace);
391     Value two = rewriter.create<LLVM::ConstantOp>(
392         loc, typeConverter->convertType(rewriter.getI32Type()),
393         rewriter.getI32IntegerAttr(2));
394     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
395         loc, indexPtrTy, scalarMemRefDescPtr,
396         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
397 
    // The size value that we have to extract can be obtained using a GEPOp
    // with `dimOp.getIndex() + 1` as the index argument.
400     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
401         loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
402     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
403                                                  ValueRange({idxPlusOne}));
404     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
405   }
406 
407   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
408     if (Optional<int64_t> idx = dimOp.getConstantIndex())
409       return idx;
410 
411     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
412       return constantOp.getValue()
413           .cast<IntegerAttr>()
414           .getValue()
415           .getSExtValue();
416 
417     return llvm::None;
418   }
419 
420   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
421                                   OpAdaptor adaptor,
422                                   ConversionPatternRewriter &rewriter) const {
423     Location loc = dimOp.getLoc();
424 
    // Take advantage of a constant index when one is available.
426     MemRefType memRefType = operandType.cast<MemRefType>();
427     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
428       int64_t i = *index;
429       if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
431         MemRefDescriptor descriptor(adaptor.getSource());
432         return descriptor.size(rewriter, loc, i);
433       }
434       // Use constant for static size.
435       int64_t dimSize = memRefType.getDimSize(i);
436       return createIndexConstant(rewriter, loc, dimSize);
437     }
438     Value index = adaptor.getIndex();
439     int64_t rank = memRefType.getRank();
440     MemRefDescriptor memrefDescriptor(adaptor.getSource());
441     return memrefDescriptor.size(rewriter, loc, index, rank);
442   }
443 };
444 
445 /// Common base for load and store operations on MemRefs. Restricts the match
446 /// to supported MemRef types. Provides functionality to emit code accessing a
447 /// specific element of the underlying data buffer.
448 template <typename Derived>
449 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
450   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
451   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
452   using Base = LoadStoreOpLowering<Derived>;
453 
454   LogicalResult match(Derived op) const override {
455     MemRefType type = op.getMemRefType();
456     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
457   }
458 };
459 
/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can
/// be retried until it succeeds in atomically storing a new value into
/// memory.
462 ///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
468 ///             |
469 ///  -------|   |
470 ///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
479 ///  |          |        |
480 ///  |-----------        |
481 ///                      v
482 ///      +--------------------------------+
483 ///      | end:                           |
484 ///      |   <code after the AtomicRMWOp> |
485 ///      +--------------------------------+
486 ///
487 struct GenericAtomicRMWOpLowering
488     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
489   using Base::Base;
490 
491   LogicalResult
492   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
493                   ConversionPatternRewriter &rewriter) const override {
494     auto loc = atomicOp.getLoc();
495     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
496 
497     // Split the block into initial, loop, and ending parts.
498     auto *initBlock = rewriter.getInsertionBlock();
499     auto *loopBlock = rewriter.createBlock(
500         initBlock->getParent(), std::next(Region::iterator(initBlock)),
501         valueType, loc);
502     auto *endBlock = rewriter.createBlock(
503         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
504 
    // The range of operations to be moved to `endBlock`.
506     auto opsToMoveStart = atomicOp->getIterator();
507     auto opsToMoveEnd = initBlock->back().getIterator();
508 
509     // Compute the loaded value and branch to the loop block.
510     rewriter.setInsertionPointToEnd(initBlock);
511     auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
512     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
513                                         adaptor.getIndices(), rewriter);
514     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
515     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
516 
517     // Prepare the body of the loop block.
518     rewriter.setInsertionPointToStart(loopBlock);
519 
520     // Clone the GenericAtomicRMWOp region and extract the result.
521     auto loopArgument = loopBlock->getArgument(0);
522     BlockAndValueMapping mapping;
523     mapping.map(atomicOp.getCurrentValue(), loopArgument);
524     Block &entryBlock = atomicOp.body().front();
525     for (auto &nestedOp : entryBlock.without_terminator()) {
526       Operation *clone = rewriter.clone(nestedOp, mapping);
527       mapping.map(nestedOp.getResults(), clone->getResults());
528     }
529     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
530 
531     // Prepare the epilog of the loop block.
532     // Append the cmpxchg op to the end of the loop block.
533     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
534     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
535     auto boolType = IntegerType::get(rewriter.getContext(), 1);
536     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
537                                                      {valueType, boolType});
538     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
539         loc, pairType, dataPtr, loopArgument, result, successOrdering,
540         failureOrdering);
541     // Extract the %new_loaded and %ok values from the pair.
542     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
543         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
544     Value ok = rewriter.create<LLVM::ExtractValueOp>(
545         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
546 
547     // Conditionally branch to the end or back to the loop depending on %ok.
548     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
549                                     loopBlock, newLoaded);
550 
551     rewriter.setInsertionPointToEnd(endBlock);
552     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
553                  std::next(opsToMoveEnd), rewriter);
554 
555     // The 'result' of the atomic_rmw op is the newly loaded value.
556     rewriter.replaceOp(atomicOp, {newLoaded});
557 
558     return success();
559   }
560 
561 private:
562   // Clones a segment of ops [start, end) and erases the original.
563   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
564                     Block::iterator start, Block::iterator end,
565                     ConversionPatternRewriter &rewriter) const {
566     BlockAndValueMapping mapping;
567     mapping.map(oldResult, newResult);
568     SmallVector<Operation *, 2> opsToErase;
569     for (auto it = start; it != end; ++it) {
570       rewriter.clone(*it, mapping);
571       opsToErase.push_back(&*it);
572     }
573     for (auto *it : opsToErase)
574       rewriter.eraseOp(it);
575   }
576 };
577 
578 /// Returns the LLVM type of the global variable given the memref type `type`.
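/// For instance, `memref<4x8xf32>` maps, roughly, to
/// `!llvm.array<4 x array<8 x f32>>` (the element type itself goes through the
/// type converter first).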
579 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
580                                           LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when lowering to the
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
586   Type elementType = typeConverter.convertType(type.getElementType());
587   Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
589   for (int64_t dim : llvm::reverse(type.getShape()))
590     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
591   return arrayTy;
592 }
593 
/// GlobalMemrefOp is lowered to an LLVM global variable.
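/// For example, roughly (a sketch; attribute and type details may differ):
///
///   memref.global "private" constant @g : memref<2xf32> = dense<[1.0, 2.0]>
///
/// becomes
///
///   llvm.mlir.global private constant @g(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>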
595 struct GlobalMemrefOpLowering
596     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
597   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
598 
599   LogicalResult
600   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
601                   ConversionPatternRewriter &rewriter) const override {
602     MemRefType type = global.getType();
603     if (!isConvertibleAndHasIdentityMaps(type))
604       return failure();
605 
606     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
607 
608     LLVM::Linkage linkage =
609         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
610 
611     Attribute initialValue = nullptr;
612     if (!global.isExternal() && !global.isUninitialized()) {
613       auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
614       initialValue = elementsAttr;
615 
616       // For scalar memrefs, the global variable created is of the element type,
617       // so unpack the elements attribute to extract the value.
618       if (type.getRank() == 0)
619         initialValue = elementsAttr.getSplatValue<Attribute>();
620     }
621 
622     uint64_t alignment = global.getAlignment().value_or(0);
623 
624     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
625         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
626         initialValue, alignment, type.getMemorySpaceAsInt());
627     if (!global.isExternal() && global.isUninitialized()) {
628       Block *blk = new Block();
629       newGlobal.getInitializerRegion().push_back(blk);
630       rewriter.setInsertionPointToStart(blk);
631       Value undef[] = {
632           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
633       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
634     }
635     return success();
636   }
637 };
638 
/// GetGlobalMemrefOp is lowered into a memref descriptor with the pointer to
/// the first element stashed into the descriptor. This builds on
/// `AllocLikeOpLowering` to reuse the memref descriptor construction.
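/// A rough sketch for `memref.get_global @g : memref<2xf32>` (the types are
/// illustrative):
///
///   %addr = llvm.mlir.addressof @g : !llvm.ptr<array<2 x f32>>
///   %gep  = llvm.getelementptr %addr[%c0, %c0]
///             : (!llvm.ptr<array<2 x f32>>, i64, i64) -> !llvm.ptr<f32>
///
/// The GEP result becomes the aligned pointer of the descriptor; the
/// allocated pointer is set to a known bad constant (0xdeadbeef) since the
/// buffer is never freed.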
642 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
643   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
644       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
645                                 converter) {}
646 
  /// Buffer "allocation" for the memref.get_global op amounts to taking the
  /// address of the referenced global variable.
649   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
650                                           Location loc, Value sizeBytes,
651                                           Operation *op) const override {
652     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
653     MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
654     unsigned memSpace = type.getMemorySpaceAsInt();
655 
656     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
657     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
658         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
659         getGlobalOp.getName());
660 
    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) zero indices.
663     Type elementType = typeConverter->convertType(type.getElementType());
664     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
665 
666     SmallVector<Value> operands;
667     operands.insert(operands.end(), type.getRank() + 1,
668                     createIndexConstant(rewriter, loc, 0));
669     auto gep =
670         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
671 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
675     auto intPtrType = getIntPtrType(memSpace);
676     Value deadBeefConst =
677         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
678     auto deadBeefPtr =
679         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
680 
    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
683     return std::make_tuple(deadBeefPtr, gep);
684   }
685 };
686 
687 // Load operation is lowered to obtaining a pointer to the indexed element
688 // and loading it.
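// E.g., `memref.load %m[%i, %j]` becomes, roughly (a sketch assuming an f32
// element type), a strided address computation followed by a plain load:
//   %ptr = llvm.getelementptr %aligned[%linearIndex] : ... -> !llvm.ptr<f32>
//   %val = llvm.load %ptr : !llvm.ptr<f32>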
689 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
690   using Base::Base;
691 
692   LogicalResult
693   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
694                   ConversionPatternRewriter &rewriter) const override {
695     auto type = loadOp.getMemRefType();
696 
697     Value dataPtr =
698         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
699                              adaptor.getIndices(), rewriter);
700     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
701     return success();
702   }
703 };
704 
705 // Store operation is lowered to obtaining a pointer to the indexed element,
706 // and storing the given value to it.
707 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
708   using Base::Base;
709 
710   LogicalResult
711   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
712                   ConversionPatternRewriter &rewriter) const override {
713     auto type = op.getMemRefType();
714 
715     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
716                                          adaptor.getIndices(), rewriter);
717     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
718     return success();
719   }
720 };
721 
722 // The prefetch operation is lowered in a way similar to the load operation
723 // except that the llvm.prefetch operation is used for replacement.
724 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
725   using Base::Base;
726 
727   LogicalResult
728   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
729                   ConversionPatternRewriter &rewriter) const override {
730     auto type = prefetchOp.getMemRefType();
731     auto loc = prefetchOp.getLoc();
732 
733     Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
734                                          adaptor.getIndices(), rewriter);
735 
736     // Replace with llvm.prefetch.
737     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
738     auto isWrite = rewriter.create<LLVM::ConstantOp>(
739         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
740     auto localityHint = rewriter.create<LLVM::ConstantOp>(
741         loc, llvmI32Type,
742         rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
743     auto isData = rewriter.create<LLVM::ConstantOp>(
744         loc, llvmI32Type,
745         rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
746 
747     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
748                                                 localityHint, isData);
749     return success();
750   }
751 };
752 
753 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
754   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
755 
756   LogicalResult
757   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
758                   ConversionPatternRewriter &rewriter) const override {
759     Location loc = op.getLoc();
760     Type operandType = op.getMemref().getType();
761     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
762       UnrankedMemRefDescriptor desc(adaptor.getMemref());
763       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
764       return success();
765     }
766     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
767       rewriter.replaceOp(
768           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
769       return success();
770     }
771     return failure();
772   }
773 };
774 
775 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
776   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
777 
778   LogicalResult match(memref::CastOp memRefCastOp) const override {
779     Type srcType = memRefCastOp.getOperand().getType();
780     Type dstType = memRefCastOp.getType();
781 
    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
787     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
788       return success(typeConverter->convertType(srcType) ==
789                      typeConverter->convertType(dstType));
790 
    // At least one of the operands is of unranked memref type.
792     assert(srcType.isa<UnrankedMemRefType>() ||
793            dstType.isa<UnrankedMemRefType>());
794 
    // Unranked-to-unranked casts are disallowed.
796     return !(srcType.isa<UnrankedMemRefType>() &&
797              dstType.isa<UnrankedMemRefType>())
798                ? success()
799                : failure();
800   }
801 
802   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
803                ConversionPatternRewriter &rewriter) const override {
804     auto srcType = memRefCastOp.getOperand().getType();
805     auto dstType = memRefCastOp.getType();
806     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
807     auto loc = memRefCastOp.getLoc();
808 
809     // For ranked/ranked case, just keep the original descriptor.
810     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
811       return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
812 
813     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
814       // Casting ranked to unranked memref type
815       // Set the rank in the destination from the memref type
816       // Allocate space on the stack and copy the src memref descriptor
817       // Set the ptr in the destination to the stack space
818       auto srcMemRefType = srcType.cast<MemRefType>();
819       int64_t rank = srcMemRefType.getRank();
820       // ptr = AllocaOp sizeof(MemRefDescriptor)
821       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
822           loc, adaptor.getSource(), rewriter);
823       // voidptr = BitCastOp srcType* to void*
824       auto voidPtr =
825           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
826               .getResult();
827       // rank = ConstantOp srcRank
828       auto rankVal = rewriter.create<LLVM::ConstantOp>(
829           loc, getIndexType(), rewriter.getIndexAttr(rank));
830       // undef = UndefOp
831       UnrankedMemRefDescriptor memRefDesc =
832           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
833       // d1 = InsertValueOp undef, rank, 0
834       memRefDesc.setRank(rewriter, loc, rankVal);
835       // d2 = InsertValueOp d1, voidptr, 1
836       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
837       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
838 
839     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
840       // Casting from unranked type to ranked.
      // The operation is assumed to perform a correct cast. If the destination
      // type mismatches the unranked type, the behavior is undefined.
843       UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
844       // ptr = ExtractValueOp src, 1
845       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
846       // castPtr = BitCastOp i8* to structTy*
847       auto castPtr =
848           rewriter
849               .create<LLVM::BitcastOp>(
850                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
851               .getResult();
852       // struct = LoadOp castPtr
853       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
854       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
855     } else {
856       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
857     }
858   }
859 };
860 
/// Pattern to lower a `memref.copy` to LLVM.
862 ///
863 /// For memrefs with identity layouts, the copy is lowered to the llvm
864 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
865 /// to the generic `MemrefCopyFn`.
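/// In the intrinsic case the emitted IR is, roughly (a sketch):
///
///   %bytes = llvm.mul %numElements, %elementSizeInBytes
///   "llvm.intr.memcpy"(%dstPtr, %srcPtr, %bytes, %false) : (...) -> ()
///
/// In the fallback case, both memrefs are packed into unranked descriptors,
/// promoted to the stack, and passed to the `memrefCopy` runtime function.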
866 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
867   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
868 
869   LogicalResult
870   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
871                           ConversionPatternRewriter &rewriter) const {
872     auto loc = op.getLoc();
873     auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
874 
875     MemRefDescriptor srcDesc(adaptor.getSource());
876 
877     // Compute number of elements.
878     Value numElements = rewriter.create<LLVM::ConstantOp>(
879         loc, getIndexType(), rewriter.getIndexAttr(1));
880     for (int pos = 0; pos < srcType.getRank(); ++pos) {
881       auto size = srcDesc.size(rewriter, loc, pos);
882       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
883     }
884 
885     // Get element size.
886     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
887     // Compute total.
888     Value totalSize =
889         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
890 
891     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
892     Value srcOffset = srcDesc.offset(rewriter, loc);
893     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
894                                                 srcBasePtr, srcOffset);
895     MemRefDescriptor targetDesc(adaptor.getTarget());
896     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
897     Value targetOffset = targetDesc.offset(rewriter, loc);
898     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
899                                                    targetBasePtr, targetOffset);
900     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
901         loc, typeConverter->convertType(rewriter.getI1Type()),
902         rewriter.getBoolAttr(false));
903     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
904                                     isVolatile);
905     rewriter.eraseOp(op);
906 
907     return success();
908   }
909 
910   LogicalResult
911   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
912                              ConversionPatternRewriter &rewriter) const {
913     auto loc = op.getLoc();
914     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
915     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
916 
917     // First make sure we have an unranked memref descriptor representation.
918     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
919       auto rank = rewriter.create<LLVM::ConstantOp>(
920           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
921       auto *typeConverter = getTypeConverter();
922       auto ptr =
923           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
924       auto voidPtr =
925           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
926               .getResult();
927       auto unrankedType =
928           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
929       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
930                                             unrankedType,
931                                             ValueRange{rank, voidPtr});
932     };
933 
934     Value unrankedSource = srcType.hasRank()
935                                ? makeUnranked(adaptor.getSource(), srcType)
936                                : adaptor.getSource();
937     Value unrankedTarget = targetType.hasRank()
938                                ? makeUnranked(adaptor.getTarget(), targetType)
939                                : adaptor.getTarget();
940 
941     // Now promote the unranked descriptors to the stack.
942     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
943                                                  rewriter.getIndexAttr(1));
944     auto promote = [&](Value desc) {
945       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
946       auto allocated =
947           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
948       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
949       return allocated;
950     };
951 
952     auto sourcePtr = promote(unrankedSource);
953     auto targetPtr = promote(unrankedTarget);
954 
955     unsigned typeSize =
956         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
957     auto elemSize = rewriter.create<LLVM::ConstantOp>(
958         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
959     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
960         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
961     rewriter.create<LLVM::CallOp>(loc, copyFn,
962                                   ValueRange{elemSize, sourcePtr, targetPtr});
963     rewriter.eraseOp(op);
964 
965     return success();
966   }
967 
968   LogicalResult
969   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
970                   ConversionPatternRewriter &rewriter) const override {
971     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
972     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
973 
974     auto isContiguousMemrefType = [](BaseMemRefType type) {
975       auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
979       return memrefType &&
980              (memrefType.getLayout().isIdentity() ||
981               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
982                isStaticShapeAndContiguousRowMajor(memrefType)));
983     };
984 
985     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
986       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
987 
988     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
989   }
990 };
991 
/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref descriptor. In the unranked case, the fields are extracted
/// from the underlying ranked descriptor.
995 static void extractPointersAndOffset(Location loc,
996                                      ConversionPatternRewriter &rewriter,
997                                      LLVMTypeConverter &typeConverter,
998                                      Value originalOperand,
999                                      Value convertedOperand,
1000                                      Value *allocatedPtr, Value *alignedPtr,
1001                                      Value *offset = nullptr) {
1002   Type operandType = originalOperand.getType();
1003   if (operandType.isa<MemRefType>()) {
1004     MemRefDescriptor desc(convertedOperand);
1005     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1006     *alignedPtr = desc.alignedPtr(rewriter, loc);
1007     if (offset != nullptr)
1008       *offset = desc.offset(rewriter, loc);
1009     return;
1010   }
1011 
1012   unsigned memorySpace =
1013       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
1014   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
1015   Type llvmElementType = typeConverter.convertType(elementType);
1016   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
1017       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
1018 
1019   // Extract pointer to the underlying ranked memref descriptor and cast it to
1020   // ElemType**.
1021   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1022   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1023 
1024   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1025       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1026   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1027       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1028   if (offset != nullptr) {
1029     *offset = UnrankedMemRefDescriptor::offset(
1030         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1031   }
1032 }
1033 
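/// Lowers `memref.reinterpret_cast` by building a fresh memref descriptor:
/// the allocated and aligned pointers are taken from the source (ranked or
/// unranked) descriptor, and the offset, sizes, and strides are filled in,
/// field by field, from the op's static attributes or dynamic operands.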
1034 struct MemRefReinterpretCastOpLowering
1035     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1036   using ConvertOpToLLVMPattern<
1037       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1038 
1039   LogicalResult
1040   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1041                   ConversionPatternRewriter &rewriter) const override {
1042     Type srcType = castOp.getSource().getType();
1043 
1044     Value descriptor;
1045     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1046                                                adaptor, &descriptor)))
1047       return failure();
1048     rewriter.replaceOp(castOp, {descriptor});
1049     return success();
1050   }
1051 
1052 private:
1053   LogicalResult convertSourceMemRefToDescriptor(
1054       ConversionPatternRewriter &rewriter, Type srcType,
1055       memref::ReinterpretCastOp castOp,
1056       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1057     MemRefType targetMemRefType =
1058         castOp.getResult().getType().cast<MemRefType>();
1059     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1060                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1061     if (!llvmTargetDescriptorTy)
1062       return failure();
1063 
1064     // Create descriptor.
1065     Location loc = castOp.getLoc();
1066     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1067 
1068     // Set allocated and aligned pointers.
1069     Value allocatedPtr, alignedPtr;
1070     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1071                              castOp.getSource(), adaptor.getSource(),
1072                              &allocatedPtr, &alignedPtr);
1073     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1074     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1075 
1076     // Set offset.
1077     if (castOp.isDynamicOffset(0))
1078       desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1079     else
1080       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1081 
1082     // Set sizes and strides.
1083     unsigned dynSizeId = 0;
1084     unsigned dynStrideId = 0;
1085     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1086       if (castOp.isDynamicSize(i))
1087         desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1088       else
1089         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1090 
1091       if (castOp.isDynamicStride(i))
1092         desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1093       else
1094         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1095     }
1096     *descriptor = desc;
1097     return success();
1098   }
1099 };
1100 
1101 struct MemRefReshapeOpLowering
1102     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1103   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1104 
1105   LogicalResult
1106   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1107                   ConversionPatternRewriter &rewriter) const override {
1108     Type srcType = reshapeOp.getSource().getType();
1109 
1110     Value descriptor;
1111     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1112                                                adaptor, &descriptor)))
1113       return failure();
1114     rewriter.replaceOp(reshapeOp, {descriptor});
1115     return success();
1116   }
1117 
1118 private:
1119   LogicalResult
1120   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1121                                   Type srcType, memref::ReshapeOp reshapeOp,
1122                                   memref::ReshapeOp::Adaptor adaptor,
1123                                   Value *descriptor) const {
1124     auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
1125     if (shapeMemRefType.hasStaticShape()) {
1126       MemRefType targetMemRefType =
1127           reshapeOp.getResult().getType().cast<MemRefType>();
1128       auto llvmTargetDescriptorTy =
1129           typeConverter->convertType(targetMemRefType)
1130               .dyn_cast_or_null<LLVM::LLVMStructType>();
1131       if (!llvmTargetDescriptorTy)
1132         return failure();
1133 
1134       // Create descriptor.
1135       Location loc = reshapeOp.getLoc();
1136       auto desc =
1137           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1138 
1139       // Set allocated and aligned pointers.
1140       Value allocatedPtr, alignedPtr;
1141       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1142                                reshapeOp.getSource(), adaptor.getSource(),
1143                                &allocatedPtr, &alignedPtr);
1144       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1145       desc.setAlignedPtr(rewriter, loc, alignedPtr);
1146 
1147       // Extract the offset and strides from the type.
1148       int64_t offset;
1149       SmallVector<int64_t> strides;
1150       if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1151         return rewriter.notifyMatchFailure(
1152             reshapeOp, "failed to get stride and offset exprs");
1153 
1154       if (!isStaticStrideOrOffset(offset))
1155         return rewriter.notifyMatchFailure(reshapeOp,
1156                                            "dynamic offset is unsupported");
1157 
1158       desc.setConstantOffset(rewriter, loc, offset);
1159 
1160       assert(targetMemRefType.getLayout().isIdentity() &&
1161              "Identity layout map is a precondition of a valid reshape op");
1162 
1163       Value stride = nullptr;
1164       int64_t targetRank = targetMemRefType.getRank();
1165       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1166         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant. Dynamic strides are instead computed below as the
          // running product of the inner dimension sizes.
1169           stride = createIndexConstant(rewriter, loc, strides[i]);
1170         } else if (!stride) {
1171           // `stride` is null only in the first iteration of the loop.  However,
1172           // since the target memref has an identity layout, we can safely set
1173           // the innermost stride to 1.
1174           stride = createIndexConstant(rewriter, loc, 1);
1175         }
1176 
1177         Value dimSize;
1178         int64_t size = targetMemRefType.getDimSize(i);
1179         // If the size of this dimension is dynamic, then load it at runtime
1180         // from the shape operand.
1181         if (!ShapedType::isDynamic(size)) {
1182           dimSize = createIndexConstant(rewriter, loc, size);
1183         } else {
1184           Value shapeOp = reshapeOp.getShape();
1185           Value index = createIndexConstant(rewriter, loc, i);
1186           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1187         }
1188 
1189         desc.setSize(rewriter, loc, i, dimSize);
1190         desc.setStride(rewriter, loc, i, stride);
1191 
1192         // Prepare the stride value for the next dimension.
1193         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1194       }
1195 
1196       *descriptor = desc;
1197       return success();
1198     }
1199 
    // The shape operand is a rank-1 memref with unknown length.
1201     Location loc = reshapeOp.getLoc();
1202     MemRefDescriptor shapeDesc(adaptor.getShape());
1203     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1204 
1205     // Extract address space and element type.
1206     auto targetType =
1207         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1208     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1209     Type elementType = targetType.getElementType();
1210 
    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
1213     auto targetDesc = UnrankedMemRefDescriptor::undef(
1214         rewriter, loc, typeConverter->convertType(targetType));
1215     targetDesc.setRank(rewriter, loc, resultRank);
1216     SmallVector<Value, 4> sizes;
1217     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1218                                            targetDesc, sizes);
1219     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1220         loc, getVoidPtrType(), sizes.front(), llvm::None);
1221     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1222 
1223     // Extract pointers and offset from the source memref.
1224     Value allocatedPtr, alignedPtr, offset;
1225     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1226                              reshapeOp.getSource(), adaptor.getSource(),
1227                              &allocatedPtr, &alignedPtr, &offset);
1228 
1229     // Set pointers and offset.
1230     Type llvmElementType = typeConverter->convertType(elementType);
1231     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1232         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1233     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1234                                               elementPtrPtrType, allocatedPtr);
1235     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1236                                             underlyingDescPtr,
1237                                             elementPtrPtrType, alignedPtr);
1238     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1239                                         underlyingDescPtr, elementPtrPtrType,
1240                                         offset);
1241 
1242     // Use the offset pointer as base for further addressing. Copy over the new
1243     // shape and compute strides. For this, we create a loop from rank-1 to 0.
1244     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1245         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1246         elementPtrPtrType);
1247     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1248         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1249     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1250     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1251     Value resultRankMinusOne =
1252         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1253 
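    // The copy is implemented as a small CFG loop built directly with the
    // rewriter (a schematic view of the blocks created below):
    //   init:          br cond(%rank - 1, /*stride=*/1)
    //   cond(%i, %s):  cond_br (%i >= 0), body, remainder
    //   body:          size[%i] = shape[%i]; stride[%i] = %s
    //                  br cond(%i - 1, %s * shape[%i])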
1254     Block *initBlock = rewriter.getInsertionBlock();
1255     Type indexType = getTypeConverter()->getIndexType();
1256     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1257 
1258     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1259                                             {indexType, indexType}, {loc, loc});
1260 
1261     // Move the remaining initBlock ops to condBlock.
1262     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1263     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1264 
1265     rewriter.setInsertionPointToEnd(initBlock);
1266     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1267                                 condBlock);
1268     rewriter.setInsertionPointToStart(condBlock);
1269     Value indexArg = condBlock->getArgument(0);
1270     Value strideArg = condBlock->getArgument(1);
1271 
1272     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1273     Value pred = rewriter.create<LLVM::ICmpOp>(
1274         loc, IntegerType::get(rewriter.getContext(), 1),
1275         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1276 
1277     Block *bodyBlock =
1278         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1279     rewriter.setInsertionPointToStart(bodyBlock);
1280 
1281     // Copy size from shape to descriptor.
1282     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1283     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1284         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1285     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1286     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1287                                       targetSizesBase, indexArg, size);
1288 
1289     // Write stride value and compute next one.
1290     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1291                                         targetStridesBase, indexArg, strideArg);
1292     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1293 
1294     // Decrement loop counter and branch back.
1295     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1296     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1297                                 condBlock);
1298 
1299     Block *remainder =
1300         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1301 
1302     // Hook up the cond exit to the remainder.
1303     rewriter.setInsertionPointToEnd(condBlock);
1304     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1305                                     llvm::None);
1306 
1307     // Reset position to beginning of new remainder block.
1308     rewriter.setInsertionPointToStart(remainder);
1309 
1310     *descriptor = targetDesc;
1311     return success();
1312   }
1313 };
1314 
1315 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1316 /// `Value`s.
1317 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1318                                       Type &llvmIndexType,
1319                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1320   return llvm::to_vector<4>(
1321       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1322         if (auto attr = value.dyn_cast<Attribute>())
1323           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1324         return value.get<Value>();
1325       }));
1326 }
1327 
/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
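/// For example, the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.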
1331 static DenseMap<int64_t, int64_t>
1332 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1333   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1334   for (auto &en : enumerate(reassociation)) {
1335     for (auto dim : en.value())
1336       expandedDimToCollapsedDim[dim] = en.index();
1337   }
1338   return expandedDimToCollapsedDim;
1339 }
1340 
1341 static OpFoldResult
1342 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1343                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1344                          MemRefDescriptor &inDesc,
1345                          ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
1347                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1348   int64_t outDimSize = outStaticShape[outDimIndex];
1349   if (!ShapedType::isDynamic(outDimSize))
1350     return b.getIndexAttr(outDimSize);
1351 
  // Compute the product of all the output dim sizes except the current one.
1354   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1355   int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
1357     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1358       continue;
1359     int64_t otherDimSize = outStaticShape[otherDimIndex];
1360     assert(!ShapedType::isDynamic(otherDimSize) &&
1361            "single dimension cannot be expanded into multiple dynamic "
1362            "dimensions");
1363     otherDimSizesMul *= otherDimSize;
1364   }
1365 
  // outDimSize = inDimSize / otherDimSizesMul
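  // For example (hypothetical types), expanding memref<?xf32> into
  // memref<?x4xf32> gives output dim 0 the size `size(input dim 0) / 4`.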
1367   int64_t inDimSize = inStaticShape[inDimIndex];
1368   Value inDimSizeDynamic =
1369       ShapedType::isDynamic(inDimSize)
1370           ? inDesc.size(b, loc, inDimIndex)
1371           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1372                                        b.getIndexAttr(inDimSize));
1373   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1374       loc, inDimSizeDynamic,
1375       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1376                                  b.getIndexAttr(otherDimSizesMul)));
1377   return outDimSizeDynamic;
1378 }
1379 
1380 static OpFoldResult getCollapsedOutputDimSize(
1381     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1382     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
1384   if (!ShapedType::isDynamic(outDimSize))
1385     return b.getIndexAttr(outDimSize);
1386 
1387   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1388   Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
1390     int64_t inDimSize = inStaticShape[inDimIndex];
1391     Value inDimSizeDynamic =
1392         ShapedType::isDynamic(inDimSize)
1393             ? inDesc.size(b, loc, inDimIndex)
1394             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1395                                          b.getIndexAttr(inDimSize));
1396     outDimSizeDynamic =
1397         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1398   }
1399   return outDimSizeDynamic;
1400 }
1401 
1402 static SmallVector<OpFoldResult, 4>
1403 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1404                         ArrayRef<ReassociationIndices> reassociation,
1405                         ArrayRef<int64_t> inStaticShape,
1406                         MemRefDescriptor &inDesc,
1407                         ArrayRef<int64_t> outStaticShape) {
1408   return llvm::to_vector<4>(llvm::map_range(
1409       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1410         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1411                                          outStaticShape[outDimIndex],
1412                                          inStaticShape, inDesc, reassociation);
1413       }));
1414 }
1415 
1416 static SmallVector<OpFoldResult, 4>
1417 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1418                        ArrayRef<ReassociationIndices> reassociation,
1419                        ArrayRef<int64_t> inStaticShape,
1420                        MemRefDescriptor &inDesc,
1421                        ArrayRef<int64_t> outStaticShape) {
1422   DenseMap<int64_t, int64_t> outDimToInDimMap =
1423       getExpandedDimToCollapsedDimMap(reassociation);
1424   return llvm::to_vector<4>(llvm::map_range(
1425       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1426         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1427                                         outStaticShape, inDesc, inStaticShape,
1428                                         reassociation, outDimToInDimMap);
1429       }));
1430 }
1431 
1432 static SmallVector<Value>
1433 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1434                       ArrayRef<ReassociationIndices> reassociation,
1435                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1436                       ArrayRef<int64_t> outStaticShape) {
1437   return outStaticShape.size() < inStaticShape.size()
1438              ? getAsValues(b, loc, llvmIndexType,
1439                            getCollapsedOutputShape(b, loc, llvmIndexType,
1440                                                    reassociation, inStaticShape,
1441                                                    inDesc, outStaticShape))
1442              : getAsValues(b, loc, llvmIndexType,
1443                            getExpandedOutputShape(b, loc, llvmIndexType,
1444                                                   reassociation, inStaticShape,
1445                                                   inDesc, outStaticShape));
1446 }
1447 
1448 static void fillInStridesForExpandedMemDescriptor(
1449     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1450     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1451   // See comments for computeExpandedLayoutMap for details on how the strides
1452   // are calculated.
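  // For example, expanding a source dimension of stride S into two result
  // dimensions of sizes [a, b] yields result strides [b * S, S].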
1453   for (auto &en : llvm::enumerate(reassociation)) {
1454     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1455     for (auto dstIndex : llvm::reverse(en.value())) {
1456       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1457       Value size = dstDesc.size(b, loc, dstIndex);
1458       currentStrideToExpand =
1459           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1460     }
1461   }
1462 }
1463 
1464 static void fillInStridesForCollapsedMemDescriptor(
1465     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1466     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1467     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1468   // See comments for computeCollapsedLayoutMap for details on how the strides
1469   // are calculated.
1470   auto srcShape = srcType.getShape();
1471   for (auto &en : llvm::enumerate(reassociation)) {
1472     rewriter.setInsertionPoint(op);
1473     auto dstIndex = en.index();
1474     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
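    // Fast path: drop trailing unit dimensions of the group; if the innermost
    // remaining dimension has a static size (or is the only one left), the
    // collapsed stride is simply that dimension's stride.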
1475     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1476       ref = ref.drop_back();
1477     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1478       dstDesc.setStride(rewriter, loc, dstIndex,
1479                         srcDesc.stride(rewriter, loc, ref.back()));
1480     } else {
1481       // Iterate over the source strides in reverse order. Skip over the
1482       // dimensions whose size is 1.
1483       // TODO: we should take the minimum stride in the reassociation group
1484       // instead of just the first where the dimension is not 1.
1485       //
1486       // +------------------------------------------------------+
1487       // | curEntry:                                            |
1488       // |   %srcStride = strides[srcIndex]                     |
1489       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1490       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1491       // +-------------------------+----------------------------+  |
1492       //                           |                               |
1493       //                           v                               |
1494       //            +-----------------------------+                |
1495       //            | nextEntry:                  |                |
1496       //            |   ...                       +---+            |
1497       //            +--------------+--------------+   |            |
1498       //                           |                  |            |
1499       //                           v                  |            |
1500       //            +-----------------------------+   |            |
1501       //            | nextEntry:                  |   |            |
1502       //            |   ...                       |   |            |
1503       //            +--------------+--------------+   |   +--------+
1504       //                           |                  |   |
1505       //                           v                  v   v
1506       //   +--------------------------------------------------+
1507       //   | continue(%newStride):                            |
1508       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1509       //   +--------------------------------------------------+
1510       OpBuilder::InsertionGuard guard(rewriter);
1511       Block *initBlock = rewriter.getInsertionBlock();
1512       Block *continueBlock =
1513           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1514       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1515       rewriter.setInsertionPointToStart(continueBlock);
1516       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1517 
1518       Block *curEntryBlock = initBlock;
      Block *nextEntryBlock = nullptr;
1520       for (auto srcIndex : llvm::reverse(ref)) {
1521         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1522           continue;
1523         rewriter.setInsertionPointToEnd(curEntryBlock);
1524         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1525         if (srcIndex == ref.front()) {
1526           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1527           break;
1528         }
        // Use the converted index type so that the comparison below has
        // matching operand types even when the index bitwidth is overridden.
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, srcDesc.getIndexType(),
            rewriter.getIntegerAttr(srcDesc.getIndexType(), 1));
1532         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1533             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1534             one);
1535         {
1536           OpBuilder::InsertionGuard guard(rewriter);
1537           nextEntryBlock = rewriter.createBlock(
1538               initBlock->getParent(), Region::iterator(continueBlock), {});
1539         }
1540         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1541                                         srcStride, nextEntryBlock, llvm::None);
1542         curEntryBlock = nextEntryBlock;
1543       }
1544     }
1545   }
1546 }
1547 
1548 static void fillInDynamicStridesForMemDescriptor(
1549     ConversionPatternRewriter &b, Location loc, Operation *op,
1550     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1551     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1552     ArrayRef<ReassociationIndices> reassociation) {
1553   if (srcType.getRank() > dstType.getRank())
1554     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1555                                            srcDesc, dstDesc, reassociation);
1556   else
1557     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1558                                           reassociation);
1559 }
1560 
1561 // ReshapeOp creates a new view descriptor of the proper rank.
// The result descriptor reuses the source buffer; only its sizes and strides
// are recomputed. Static strides are emitted as constants; dynamic strides are
// either recomputed from the sizes (identity layouts) or derived from the
// source strides via the reassociation maps.
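// For illustration, a collapse such as
//   %r = memref.collapse_shape %m [[0, 1]] : memref<4x?xf32> into memref<?xf32>
// keeps the source buffer and only rewrites the size/stride fields of the
// result descriptor.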
1564 template <typename ReshapeOp>
1565 class ReassociatingReshapeOpConversion
1566     : public ConvertOpToLLVMPattern<ReshapeOp> {
1567 public:
1568   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1569   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1570 
1571   LogicalResult
1572   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1573                   ConversionPatternRewriter &rewriter) const override {
1574     MemRefType dstType = reshapeOp.getResultType();
1575     MemRefType srcType = reshapeOp.getSrcType();
1576 
1577     int64_t offset;
1578     SmallVector<int64_t, 4> strides;
1579     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1580       return rewriter.notifyMatchFailure(
1581           reshapeOp, "failed to get stride and offset exprs");
1582     }
1583 
1584     MemRefDescriptor srcDesc(adaptor.getSrc());
1585     Location loc = reshapeOp->getLoc();
1586     auto dstDesc = MemRefDescriptor::undef(
1587         rewriter, loc, this->typeConverter->convertType(dstType));
1588     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1589     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1590     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1591 
1592     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1593     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1594     Type llvmIndexType =
1595         this->typeConverter->convertType(rewriter.getIndexType());
1596     SmallVector<Value> dstShape = getDynamicOutputShape(
1597         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1598         srcStaticShape, srcDesc, dstStaticShape);
1599     for (auto &en : llvm::enumerate(dstShape))
1600       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1601 
1602     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1603       for (auto &en : llvm::enumerate(strides))
1604         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1605     } else if (srcType.getLayout().isIdentity() &&
1606                dstType.getLayout().isIdentity()) {
1607       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1608                                                    rewriter.getIndexAttr(1));
1609       Value stride = c1;
1610       for (auto dimIndex :
1611            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1612         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1613         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1614       }
1615     } else {
1616       // There could be mixed static/dynamic strides. For simplicity, we
1617       // recompute all strides if there is at least one dynamic stride.
1618       fillInDynamicStridesForMemDescriptor(
1619           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1620           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1621     }
1622     rewriter.replaceOp(reshapeOp, {dstDesc});
1623     return success();
1624   }
1625 };
1626 
1627 /// Conversion pattern that transforms a subview op into:
1628 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1629 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1630 ///      and stride.
1631 /// The subview op is replaced by the descriptor.
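/// Schematically, for `memref.subview %src[%o0, %o1] [4, 4] [1, 1]` the result
/// descriptor reuses the source pointers, its offset becomes
/// `src_offset + %o0 * src_stride0 + %o1 * src_stride1`, its sizes become
/// [4, 4], and its strides become [src_stride0, src_stride1].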
1632 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1633   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1634 
1635   LogicalResult
1636   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1637                   ConversionPatternRewriter &rewriter) const override {
1638     auto loc = subViewOp.getLoc();
1639 
1640     auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
1641     auto sourceElementTy =
1642         typeConverter->convertType(sourceMemRefType.getElementType());
1643 
1644     auto viewMemRefType = subViewOp.getType();
1645     auto inferredType =
1646         memref::SubViewOp::inferResultType(
1647             subViewOp.getSourceType(),
1648             extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1649             extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1650             extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
1651             .cast<MemRefType>();
1652     auto targetElementTy =
1653         typeConverter->convertType(viewMemRefType.getElementType());
1654     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1655     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1656         !LLVM::isCompatibleType(sourceElementTy) ||
1657         !LLVM::isCompatibleType(targetElementTy) ||
1658         !LLVM::isCompatibleType(targetDescTy))
1659       return failure();
1660 
1661     // Extract the offset and strides from the type.
1662     int64_t offset;
1663     SmallVector<int64_t, 4> strides;
1664     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1665     if (failed(successStrides))
1666       return failure();
1667 
1668     // Create the descriptor.
1669     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1670       return failure();
1671     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1672     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1673 
1674     // Copy the buffer pointer from the old descriptor to the new one.
1675     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1676     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1677         loc,
1678         LLVM::LLVMPointerType::get(targetElementTy,
1679                                    viewMemRefType.getMemorySpaceAsInt()),
1680         extracted);
1681     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1682 
1683     // Copy the aligned pointer from the old descriptor to the new one.
1684     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1685     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1686         loc,
1687         LLVM::LLVMPointerType::get(targetElementTy,
1688                                    viewMemRefType.getMemorySpaceAsInt()),
1689         extracted);
1690     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1691 
1692     size_t inferredShapeRank = inferredType.getRank();
1693     size_t resultShapeRank = viewMemRefType.getRank();
1694 
1695     // Extract strides needed to compute offset.
1696     SmallVector<Value, 4> strideValues;
1697     strideValues.reserve(inferredShapeRank);
1698     for (unsigned i = 0; i < inferredShapeRank; ++i)
1699       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1700 
1701     // Offset.
1702     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1703     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1704       targetMemRef.setConstantOffset(rewriter, loc, offset);
1705     } else {
1706       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1707       // `inferredShapeRank` may be larger than the number of offset operands
1708       // because of trailing semantics. In this case, the offset is guaranteed
1709       // to be interpreted as 0 and we can just skip the extra dimensions.
1710       for (unsigned i = 0, e = std::min(inferredShapeRank,
1711                                         subViewOp.getMixedOffsets().size());
1712            i < e; ++i) {
1713         Value offset =
1714             // TODO: need OpFoldResult ODS adaptor to clean this up.
1715             subViewOp.isDynamicOffset(i)
1716                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1717                 : rewriter.create<LLVM::ConstantOp>(
1718                       loc, llvmIndexType,
1719                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1720         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1721         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1722       }
1723       targetMemRef.setOffset(rewriter, loc, baseOffset);
1724     }
1725 
1726     // Update sizes and strides.
1727     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1728     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1729     assert(mixedSizes.size() == mixedStrides.size() &&
1730            "expected sizes and strides of equal length");
1731     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1732     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1733          i >= 0 && j >= 0; --i) {
1734       if (unusedDims.test(i))
1735         continue;
1736 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is taken from
      // the corresponding source dimension and the stride is 1.
1740       Value size, stride;
1741       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1742         // If the static size is available, use it directly. This is similar to
1743         // the folding of dim(constant-op) but removes the need for dim to be
1744         // aware of LLVM constants and for this pass to be aware of std
1745         // constants.
1746         int64_t staticSize =
1747             subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
1748         if (staticSize != ShapedType::kDynamicSize) {
1749           size = rewriter.create<LLVM::ConstantOp>(
1750               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1751         } else {
1752           Value pos = rewriter.create<LLVM::ConstantOp>(
1753               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1754           Value dim =
1755               rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1756           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1757               loc, llvmIndexType, dim);
1758           size = cast.getResult(0);
1759         }
1760         stride = rewriter.create<LLVM::ConstantOp>(
1761             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1762       } else {
1763         // TODO: need OpFoldResult ODS adaptor to clean this up.
1764         size =
1765             subViewOp.isDynamicSize(i)
1766                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1767                 : rewriter.create<LLVM::ConstantOp>(
1768                       loc, llvmIndexType,
1769                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1770         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1771           stride = rewriter.create<LLVM::ConstantOp>(
1772               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1773         } else {
1774           stride =
1775               subViewOp.isDynamicStride(i)
1776                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1777                   : rewriter.create<LLVM::ConstantOp>(
1778                         loc, llvmIndexType,
1779                         rewriter.getI64IntegerAttr(
1780                             subViewOp.getStaticStride(i)));
1781           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1782         }
1783       }
1784       targetMemRef.setSize(rewriter, loc, j, size);
1785       targetMemRef.setStride(rewriter, loc, j, stride);
1786       j--;
1787     }
1788 
1789     rewriter.replaceOp(subViewOp, {targetMemRef});
1790     return success();
1791   }
1792 };
1793 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
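/// For example, a permutation map (d0, d1) -> (d1, d0) simply swaps the
/// size/stride entries of dimensions 0 and 1 in the result descriptor.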
1801 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1802 public:
1803   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1804 
1805   LogicalResult
1806   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1807                   ConversionPatternRewriter &rewriter) const override {
1808     auto loc = transposeOp.getLoc();
1809     MemRefDescriptor viewMemRef(adaptor.getIn());
1810 
1811     // No permutation, early exit.
1812     if (transposeOp.getPermutation().isIdentity())
1813       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1814 
1815     auto targetMemRef = MemRefDescriptor::undef(
1816         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1817 
1818     // Copy the base and aligned pointers from the old descriptor to the new
1819     // one.
1820     targetMemRef.setAllocatedPtr(rewriter, loc,
1821                                  viewMemRef.allocatedPtr(rewriter, loc));
1822     targetMemRef.setAlignedPtr(rewriter, loc,
1823                                viewMemRef.alignedPtr(rewriter, loc));
1824 
1825     // Copy the offset pointer from the old descriptor to the new one.
1826     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1827 
1828     // Iterate over the dimensions and apply size/stride permutation.
1829     for (const auto &en :
1830          llvm::enumerate(transposeOp.getPermutation().getResults())) {
1831       int sourcePos = en.index();
1832       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1833       targetMemRef.setSize(rewriter, loc, targetPos,
1834                            viewMemRef.size(rewriter, loc, sourcePos));
1835       targetMemRef.setStride(rewriter, loc, targetPos,
1836                              viewMemRef.stride(rewriter, loc, sourcePos));
1837     }
1838 
1839     rewriter.replaceOp(transposeOp, {targetMemRef});
1840     return success();
1841   }
1842 };
1843 
/// Conversion pattern that transforms a memref.view op into:
1845 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1846 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1847 ///      and stride.
1848 /// The view op is replaced by the descriptor.
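/// For illustration, a `memref.view` of a `memref<...xi8>` buffer reuses the
/// source allocation, advances the aligned pointer by the byte-shift operand,
/// and fills in the sizes and contiguous (row-major) strides of the result
/// type.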
1849 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1850   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1851 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by picking the proper dynamic
  // size operand.
1854   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1855                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1856                 unsigned idx) const {
1857     assert(idx < shape.size());
1858     if (!ShapedType::isDynamic(shape[idx]))
1859       return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in the range [0, idx).
1861     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1862       return ShapedType::isDynamic(v);
1863     });
1864     return dynamicSizes[nDynamic];
1865   }
1866 
1867   // Build and return the idx^th stride, either by returning the constant stride
1868   // or by computing the dynamic stride from the current `runningStride` and
1869   // `nextSize`. The caller should keep a running stride and update it with the
1870   // result returned by this function.
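  // For a contiguous result of shape [s0, s1, s2], the bottom-up walk in
  // matchAndRewrite produces strides [s1 * s2, s2, 1], reusing static stride
  // values where available.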
1871   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1872                   ArrayRef<int64_t> strides, Value nextSize,
1873                   Value runningStride, unsigned idx) const {
1874     assert(idx < strides.size());
1875     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1876       return createIndexConstant(rewriter, loc, strides[idx]);
1877     if (nextSize)
1878       return runningStride
1879                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1880                  : nextSize;
1881     assert(!runningStride);
1882     return createIndexConstant(rewriter, loc, 1);
1883   }
1884 
1885   LogicalResult
1886   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1887                   ConversionPatternRewriter &rewriter) const override {
1888     auto loc = viewOp.getLoc();
1889 
1890     auto viewMemRefType = viewOp.getType();
1891     auto targetElementTy =
1892         typeConverter->convertType(viewMemRefType.getElementType());
1893     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1894     if (!targetDescTy || !targetElementTy ||
1895         !LLVM::isCompatibleType(targetElementTy) ||
1896         !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
1898              failure();
1899 
1900     int64_t offset;
1901     SmallVector<int64_t, 4> strides;
1902     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1903     if (failed(successStrides))
1904       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1905     assert(offset == 0 && "expected offset to be 0");
1906 
1907     // Target memref must be contiguous in memory (innermost stride is 1), or
1908     // empty (special case when at least one of the memref dimensions is 0).
1909     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1910       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1911              failure();
1912 
1913     // Create the descriptor.
1914     MemRefDescriptor sourceMemRef(adaptor.getSource());
1915     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1916 
1917     // Field 1: Copy the allocated pointer, used for malloc/free.
1918     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1919     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
1920     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1921         loc,
1922         LLVM::LLVMPointerType::get(targetElementTy,
1923                                    srcMemRefType.getMemorySpaceAsInt()),
1924         allocatedPtr);
1925     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1926 
1927     // Field 2: Copy the actual aligned pointer to payload.
1928     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1929     alignedPtr = rewriter.create<LLVM::GEPOp>(
1930         loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
1931     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1932         loc,
1933         LLVM::LLVMPointerType::get(targetElementTy,
1934                                    srcMemRefType.getMemorySpaceAsInt()),
1935         alignedPtr);
1936     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1937 
1938     // Field 3: The offset in the resulting type must be 0. This is because of
1939     // the type change: an offset on srcType* may not be expressible as an
1940     // offset on dstType*.
1941     targetMemRef.setOffset(rewriter, loc,
1942                            createIndexConstant(rewriter, loc, offset));
1943 
1944     // Early exit for 0-D corner case.
1945     if (viewMemRefType.getRank() == 0)
1946       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1947 
1948     // Fields 4 and 5: Update sizes and strides.
1949     Value stride = nullptr, nextSize = nullptr;
1950     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1951       // Update size.
1952       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1953                            adaptor.getSizes(), i);
1954       targetMemRef.setSize(rewriter, loc, i, size);
1955       // Update stride.
1956       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1957       targetMemRef.setStride(rewriter, loc, i, stride);
1958       nextSize = size;
1959     }
1960 
1961     rewriter.replaceOp(viewOp, {targetMemRef});
1962     return success();
1963   }
1964 };
1965 
1966 //===----------------------------------------------------------------------===//
1967 // AtomicRMWOpLowering
1968 //===----------------------------------------------------------------------===//
1969 
/// Try to match the kind of a memref.atomic_rmw to determine whether it can be
/// lowered to llvm.atomicrmw or needs to fall back to llvm.cmpxchg.
1972 static Optional<LLVM::AtomicBinOp>
1973 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1974   switch (atomicOp.getKind()) {
1975   case arith::AtomicRMWKind::addf:
1976     return LLVM::AtomicBinOp::fadd;
1977   case arith::AtomicRMWKind::addi:
1978     return LLVM::AtomicBinOp::add;
1979   case arith::AtomicRMWKind::assign:
1980     return LLVM::AtomicBinOp::xchg;
1981   case arith::AtomicRMWKind::maxs:
1982     return LLVM::AtomicBinOp::max;
1983   case arith::AtomicRMWKind::maxu:
1984     return LLVM::AtomicBinOp::umax;
1985   case arith::AtomicRMWKind::mins:
1986     return LLVM::AtomicBinOp::min;
1987   case arith::AtomicRMWKind::minu:
1988     return LLVM::AtomicBinOp::umin;
1989   case arith::AtomicRMWKind::ori:
1990     return LLVM::AtomicBinOp::_or;
1991   case arith::AtomicRMWKind::andi:
1992     return LLVM::AtomicBinOp::_and;
1993   default:
1994     return llvm::None;
1995   }
1996   llvm_unreachable("Invalid AtomicRMWKind");
1997 }
1998 
1999 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
2000   using Base::Base;
2001 
2002   LogicalResult
2003   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
2004                   ConversionPatternRewriter &rewriter) const override {
2005     if (failed(match(atomicOp)))
2006       return failure();
2007     auto maybeKind = matchSimpleAtomicOp(atomicOp);
2008     if (!maybeKind)
2009       return failure();
2010     auto resultType = adaptor.getValue().getType();
2011     auto memRefType = atomicOp.getMemRefType();
2012     auto dataPtr =
2013         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2014                              adaptor.getIndices(), rewriter);
2015     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
2016         atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
2017         LLVM::AtomicOrdering::acq_rel);
2018     return success();
2019   }
2020 };
2021 
2022 } // namespace
2023 
2024 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
2025                                                   RewritePatternSet &patterns) {
2026   // clang-format off
2027   patterns.add<
2028       AllocaOpLowering,
2029       AllocaScopeOpLowering,
2030       AtomicRMWOpLowering,
2031       AssumeAlignmentOpLowering,
2032       DimOpLowering,
2033       GenericAtomicRMWOpLowering,
2034       GlobalMemrefOpLowering,
2035       GetGlobalMemrefOpLowering,
2036       LoadOpLowering,
2037       MemRefCastOpLowering,
2038       MemRefCopyOpLowering,
2039       MemRefReinterpretCastOpLowering,
2040       MemRefReshapeOpLowering,
2041       PrefetchOpLowering,
2042       RankOpLowering,
2043       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2044       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2045       StoreOpLowering,
2046       SubViewOpLowering,
2047       TransposeOpLowering,
2048       ViewOpLowering>(converter);
2049   // clang-format on
2050   auto allocLowering = converter.getOptions().allocLowering;
2051   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2052     patterns.add<AlignedAllocOpLowering, AlignedDeallocOpLowering>(converter);
2053   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2054     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
2055 }
2056 
2057 namespace {
2058 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2059   MemRefToLLVMPass() = default;
2060 
2061   void runOnOperation() override {
2062     Operation *op = getOperation();
2063     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2064     LowerToLLVMOptions options(&getContext(),
2065                                dataLayoutAnalysis.getAtOrAbove(op));
2066     options.allocLowering =
2067         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2068                          : LowerToLLVMOptions::AllocLowering::Malloc);
2069     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2070       options.overrideIndexBitwidth(indexBitwidth);
2071 
2072     LLVMTypeConverter typeConverter(&getContext(), options,
2073                                     &dataLayoutAnalysis);
2074     RewritePatternSet patterns(&getContext());
2075     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2076     LLVMConversionTarget target(getContext());
2077     target.addLegalOp<func::FuncOp>();
2078     if (failed(applyPartialConversion(op, target, std::move(patterns))))
2079       signalPassFailure();
2080   }
2081 };
2082 } // namespace
2083 
2084 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2085   return std::make_unique<MemRefToLLVMPass>();
2086 }
2087