//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

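/// Lowers `memref.alloc` to a heap allocation through `malloc`, adjusting the
/// size and the resulting pointer by hand when an alignment is requested.
/// Illustrative sketch only; the exact constants and types depend on the data
/// layout and the index bitwidth:
///
///   %m = memref.alloc() : memref<32xf32>
/// becomes, approximately,
///   %size = llvm.mlir.constant(128 : index) : i64
///   %raw  = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
///   %ptr  = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
///   ... %ptr is then stashed into the memref descriptor ...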
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.getAlignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

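/// Lowers `memref.alloc` through `aligned_alloc` when the lowering is
/// configured to use it. Illustrative sketch only; the alignment and size are
/// computed from the attribute and the data layout:
///
///   %m = memref.alloc() {alignment = 32} : memref<32xf32>
/// becomes, approximately,
///   %align = llvm.mlir.constant(32 : index) : i64
///   %size  = llvm.mlir.constant(128 : index) : i64
///   %raw   = llvm.call @aligned_alloc(%align, %size)
///              : (i64, i64) -> !llvm.ptr<i8>
///   %ptr   = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>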
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we bump the element size to the next power of two if it
    // isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer with an llvm.alloca. Since an alloca
  /// already yields a suitably aligned pointer to the element type, the
  /// allocated and aligned pointers returned are the same value.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // For stack allocations: with alloca, one gets a pointer to the element
    // type right away.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack, then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
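// Illustrative sketch only (types abbreviated, not a verbatim lowering):
//   memref.dealloc %m : memref<?xf32>
// becomes, approximately,
//   %ptr  = llvm.extractvalue %desc[0] : !llvm.struct<...>
//   %void = llvm.bitcast %ptr : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()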
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
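// Illustrative sketch for the ranked case (descriptor field indices are an
// assumption based on the usual {allocated, aligned, offset, sizes, strides}
// layout):
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
// becomes either an index constant (static dimension) or, approximately,
//   %d = llvm.extractvalue %desc[3, 1] : !llvm.struct<...>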
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.getIndex() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index when available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.getSource());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
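/// For example (sketch), a global of type memref<4x8xf32> maps to the LLVM
/// type !llvm.array<4 x array<8 x f32>>; the element type itself goes through
/// the type converter first.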
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so we need to walk it
  // backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// memref.global is lowered to an LLVM global variable.
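/// Illustrative sketch only (attributes and exact syntax elided):
///   memref.global "private" @gv : memref<2xf32> = dense<[1.0, 2.0]>
/// becomes, approximately,
///   llvm.mlir.global private @gv(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>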
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// memref.get_global is lowered into a memref descriptor with the pointer to
/// the first element of the global stashed into it. This builds on
/// `AllocLikeOpLLVMLowering` to reuse the memref descriptor construction.
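/// Illustrative sketch only (descriptor construction elided):
///   %m = memref.get_global @gv : memref<2xf32>
/// becomes, approximately,
///   %addr = llvm.mlir.addressof @gv : !llvm.ptr<array<2 x f32>>
///   %ptr  = llvm.getelementptr %addr[%c0, %c0]
///             : (!llvm.ptr<array<2 x f32>>, i64, i64) -> !llvm.ptr<f32>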
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for a memref.get_global op is getting the address of
  /// the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debugging if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
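// Illustrative sketch for the 1-D identity-layout case (the strided address
// computation is elided):
//   %v = memref.load %m[%i] : memref<8xf32>
// becomes, approximately,
//   %p = llvm.getelementptr %aligned[%i]
//          : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
//   %v = llvm.load %p : !llvm.ptr<f32>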
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is of unranked memref type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
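/// Illustrative sketch of the contiguous case only (the byte-size computation
/// and pointer adjustments are elided, and the printed intrinsic form is an
/// approximation):
///   memref.copy %src, %dst : memref<4xf32> to memref<4xf32>
/// becomes, approximately,
///   "llvm.intr.memcpy"(%dstPtr, %srcPtr, %sizeBytes, %false)
///       : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()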
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.getSource(), srcType)
                               : adaptor.getSource();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.getTarget(), targetType)
                               : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant; a dynamic stride keeps the running product of the inner
          // dimension sizes computed at the end of this loop.
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexConstant(rewriter, loc, 1);
        }

        Value dimSize;
        int64_t size = targetMemRefType.getDimSize(i);
        // Use a constant for a static size; otherwise load the size at runtime
        // from the shape operand.
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }
1177 
1178       *descriptor = desc;
1179       return success();
1180     }
1181 
1182     // The shape is a rank-1 tensor with unknown length.
1183     Location loc = reshapeOp.getLoc();
1184     MemRefDescriptor shapeDesc(adaptor.getShape());
1185     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1186 
1187     // Extract address space and element type.
1188     auto targetType =
1189         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1190     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1191     Type elementType = targetType.getElementType();
1192 
1193     // Create the unranked memref descriptor that holds the ranked one. The
1194     // inner descriptor is allocated on stack.
1195     auto targetDesc = UnrankedMemRefDescriptor::undef(
1196         rewriter, loc, typeConverter->convertType(targetType));
1197     targetDesc.setRank(rewriter, loc, resultRank);
1198     SmallVector<Value, 4> sizes;
1199     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1200                                            targetDesc, sizes);
1201     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1202         loc, getVoidPtrType(), sizes.front(), llvm::None);
1203     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1204 
1205     // Extract pointers and offset from the source memref.
1206     Value allocatedPtr, alignedPtr, offset;
1207     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1208                              reshapeOp.getSource(), adaptor.getSource(),
1209                              &allocatedPtr, &alignedPtr, &offset);
1210 
1211     // Set pointers and offset.
1212     Type llvmElementType = typeConverter->convertType(elementType);
1213     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1214         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1215     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1216                                               elementPtrPtrType, allocatedPtr);
1217     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1218                                             underlyingDescPtr,
1219                                             elementPtrPtrType, alignedPtr);
1220     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1221                                         underlyingDescPtr, elementPtrPtrType,
1222                                         offset);
1223 
    // Use pointers into the underlying descriptor as the base for further
    // addressing. Copy over the new shape and compute the strides. For this,
    // we emit a loop that iterates from rank - 1 down to 0.
1226     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1227         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1228         elementPtrPtrType);
1229     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1230         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1231     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1232     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1233     Value resultRankMinusOne =
1234         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
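
    // The code below emits roughly the following control flow (a sketch; the
    // block names are illustrative only):
    //
    //   initBlock:
    //     llvm.br ^cond(%rank - 1, %one)
    //   ^cond(%idx, %stride):
    //     %pred = llvm.icmp "sge" %idx, %zero
    //     llvm.cond_br %pred, ^body, ^remainder
    //   ^body:
    //     copy size[%idx] from the shape operand into the descriptor
    //     store %stride into the descriptor
    //     llvm.br ^cond(%idx - 1, %stride * size[%idx])
    //   ^remainder:
    //     ...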
1235 
1236     Block *initBlock = rewriter.getInsertionBlock();
1237     Type indexType = getTypeConverter()->getIndexType();
1238     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1239 
1240     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1241                                             {indexType, indexType}, {loc, loc});
1242 
1243     // Move the remaining initBlock ops to condBlock.
1244     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1245     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1246 
1247     rewriter.setInsertionPointToEnd(initBlock);
1248     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1249                                 condBlock);
1250     rewriter.setInsertionPointToStart(condBlock);
1251     Value indexArg = condBlock->getArgument(0);
1252     Value strideArg = condBlock->getArgument(1);
1253 
1254     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1255     Value pred = rewriter.create<LLVM::ICmpOp>(
1256         loc, IntegerType::get(rewriter.getContext(), 1),
1257         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1258 
1259     Block *bodyBlock =
1260         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1261     rewriter.setInsertionPointToStart(bodyBlock);
1262 
1263     // Copy size from shape to descriptor.
1264     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1265     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1266         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1267     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1268     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1269                                       targetSizesBase, indexArg, size);
1270 
1271     // Write stride value and compute next one.
1272     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1273                                         targetStridesBase, indexArg, strideArg);
1274     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1275 
1276     // Decrement loop counter and branch back.
1277     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1278     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1279                                 condBlock);
1280 
1281     Block *remainder =
1282         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1283 
1284     // Hook up the cond exit to the remainder.
1285     rewriter.setInsertionPointToEnd(condBlock);
1286     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1287                                     llvm::None);
1288 
1289     // Reset position to beginning of new remainder block.
1290     rewriter.setInsertionPointToStart(remainder);
1291 
1292     *descriptor = targetDesc;
1293     return success();
1294   }
1295 };
1296 
1297 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1298 /// `Value`s.
1299 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1300                                       Type &llvmIndexType,
1301                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1302   return llvm::to_vector<4>(
1303       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1304         if (auto attr = value.dyn_cast<Attribute>())
1305           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1306         return value.get<Value>();
1307       }));
1308 }
1309 
/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
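/// For example, the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.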
1313 static DenseMap<int64_t, int64_t>
1314 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1315   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1316   for (auto &en : enumerate(reassociation)) {
1317     for (auto dim : en.value())
1318       expandedDimToCollapsedDim[dim] = en.index();
1319   }
1320   return expandedDimToCollapsedDim;
1321 }
1322 
1323 static OpFoldResult
1324 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1325                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1326                          MemRefDescriptor &inDesc,
1327                          ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
1329                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1330   int64_t outDimSize = outStaticShape[outDimIndex];
1331   if (!ShapedType::isDynamic(outDimSize))
1332     return b.getIndexAttr(outDimSize);
1333 
  // Compute the product of all the static output dim sizes in this
  // reassociation group except the current dim.
1336   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1337   int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
1339     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1340       continue;
1341     int64_t otherDimSize = outStaticShape[otherDimIndex];
1342     assert(!ShapedType::isDynamic(otherDimSize) &&
1343            "single dimension cannot be expanded into multiple dynamic "
1344            "dimensions");
1345     otherDimSizesMul *= otherDimSize;
1346   }
1347 
  // outDimSize = inDimSize / otherDimSizesMul
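  // As an illustration: expanding an input dimension of dynamic size %n into
  // an output group of sizes [?, 2, 2] computes the dynamic output size as
  // %n / 4, where 4 is the product of the remaining static sizes.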
1349   int64_t inDimSize = inStaticShape[inDimIndex];
1350   Value inDimSizeDynamic =
1351       ShapedType::isDynamic(inDimSize)
1352           ? inDesc.size(b, loc, inDimIndex)
1353           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1354                                        b.getIndexAttr(inDimSize));
1355   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1356       loc, inDimSizeDynamic,
1357       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1358                                  b.getIndexAttr(otherDimSizesMul)));
1359   return outDimSizeDynamic;
1360 }
1361 
1362 static OpFoldResult getCollapsedOutputDimSize(
1363     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1364     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
1366   if (!ShapedType::isDynamic(outDimSize))
1367     return b.getIndexAttr(outDimSize);
1368 
1369   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1370   Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
1372     int64_t inDimSize = inStaticShape[inDimIndex];
1373     Value inDimSizeDynamic =
1374         ShapedType::isDynamic(inDimSize)
1375             ? inDesc.size(b, loc, inDimIndex)
1376             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1377                                          b.getIndexAttr(inDimSize));
1378     outDimSizeDynamic =
1379         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1380   }
1381   return outDimSizeDynamic;
1382 }
1383 
1384 static SmallVector<OpFoldResult, 4>
1385 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1386                         ArrayRef<ReassociationIndices> reassociation,
1387                         ArrayRef<int64_t> inStaticShape,
1388                         MemRefDescriptor &inDesc,
1389                         ArrayRef<int64_t> outStaticShape) {
1390   return llvm::to_vector<4>(llvm::map_range(
1391       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1392         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1393                                          outStaticShape[outDimIndex],
1394                                          inStaticShape, inDesc, reassociation);
1395       }));
1396 }
1397 
1398 static SmallVector<OpFoldResult, 4>
1399 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1400                        ArrayRef<ReassociationIndices> reassociation,
1401                        ArrayRef<int64_t> inStaticShape,
1402                        MemRefDescriptor &inDesc,
1403                        ArrayRef<int64_t> outStaticShape) {
1404   DenseMap<int64_t, int64_t> outDimToInDimMap =
1405       getExpandedDimToCollapsedDimMap(reassociation);
1406   return llvm::to_vector<4>(llvm::map_range(
1407       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1408         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1409                                         outStaticShape, inDesc, inStaticShape,
1410                                         reassociation, outDimToInDimMap);
1411       }));
1412 }
1413 
1414 static SmallVector<Value>
1415 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1416                       ArrayRef<ReassociationIndices> reassociation,
1417                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1418                       ArrayRef<int64_t> outStaticShape) {
1419   return outStaticShape.size() < inStaticShape.size()
1420              ? getAsValues(b, loc, llvmIndexType,
1421                            getCollapsedOutputShape(b, loc, llvmIndexType,
1422                                                    reassociation, inStaticShape,
1423                                                    inDesc, outStaticShape))
1424              : getAsValues(b, loc, llvmIndexType,
1425                            getExpandedOutputShape(b, loc, llvmIndexType,
1426                                                   reassociation, inStaticShape,
1427                                                   inDesc, outStaticShape));
1428 }
1429 
1430 static void fillInStridesForExpandedMemDescriptor(
1431     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1432     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1433   // See comments for computeExpandedLayoutMap for details on how the strides
1434   // are calculated.
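  // As an illustration: a source dimension with stride %s expanded into three
  // output dimensions of sizes [a, b, c] receives the strides
  // [%s * c * b, %s * c, %s], assigned innermost-first.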
1435   for (auto &en : llvm::enumerate(reassociation)) {
1436     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1437     for (auto dstIndex : llvm::reverse(en.value())) {
1438       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1439       Value size = dstDesc.size(b, loc, dstIndex);
1440       currentStrideToExpand =
1441           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1442     }
1443   }
1444 }
1445 
1446 static void fillInStridesForCollapsedMemDescriptor(
1447     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1448     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1449     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1450   // See comments for computeCollapsedLayoutMap for details on how the strides
1451   // are calculated.
1452   auto srcShape = srcType.getShape();
1453   for (auto &en : llvm::enumerate(reassociation)) {
1454     rewriter.setInsertionPoint(op);
1455     auto dstIndex = en.index();
1456     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1457     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1458       ref = ref.drop_back();
1459     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1460       dstDesc.setStride(rewriter, loc, dstIndex,
1461                         srcDesc.stride(rewriter, loc, ref.back()));
1462     } else {
1463       // Iterate over the source strides in reverse order. Skip over the
1464       // dimensions whose size is 1.
1465       // TODO: we should take the minimum stride in the reassociation group
1466       // instead of just the first where the dimension is not 1.
1467       //
1468       // +------------------------------------------------------+
1469       // | curEntry:                                            |
1470       // |   %srcStride = strides[srcIndex]                     |
1471       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1472       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1473       // +-------------------------+----------------------------+  |
1474       //                           |                               |
1475       //                           v                               |
1476       //            +-----------------------------+                |
1477       //            | nextEntry:                  |                |
1478       //            |   ...                       +---+            |
1479       //            +--------------+--------------+   |            |
1480       //                           |                  |            |
1481       //                           v                  |            |
1482       //            +-----------------------------+   |            |
1483       //            | nextEntry:                  |   |            |
1484       //            |   ...                       |   |            |
1485       //            +--------------+--------------+   |   +--------+
1486       //                           |                  |   |
1487       //                           v                  v   v
1488       //   +--------------------------------------------------+
1489       //   | continue(%newStride):                            |
1490       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1491       //   +--------------------------------------------------+
1492       OpBuilder::InsertionGuard guard(rewriter);
1493       Block *initBlock = rewriter.getInsertionBlock();
1494       Block *continueBlock =
1495           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1496       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1497       rewriter.setInsertionPointToStart(continueBlock);
1498       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1499 
1500       Block *curEntryBlock = initBlock;
1501       Block *nextEntryBlock;
1502       for (auto srcIndex : llvm::reverse(ref)) {
1503         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1504           continue;
1505         rewriter.setInsertionPointToEnd(curEntryBlock);
1506         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1507         if (srcIndex == ref.front()) {
1508           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1509           break;
1510         }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI64IntegerAttr(1));
1514         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1515             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1516             one);
1517         {
1518           OpBuilder::InsertionGuard guard(rewriter);
1519           nextEntryBlock = rewriter.createBlock(
1520               initBlock->getParent(), Region::iterator(continueBlock), {});
1521         }
1522         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1523                                         srcStride, nextEntryBlock, llvm::None);
1524         curEntryBlock = nextEntryBlock;
1525       }
1526     }
1527   }
1528 }
1529 
1530 static void fillInDynamicStridesForMemDescriptor(
1531     ConversionPatternRewriter &b, Location loc, Operation *op,
1532     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1533     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1534     ArrayRef<ReassociationIndices> reassociation) {
1535   if (srcType.getRank() > dstType.getRank())
1536     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1537                                            srcDesc, dstDesc, reassociation);
1538   else
1539     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1540                                           reassociation);
1541 }
1542 
// ReshapeOp creates a new view descriptor of the proper rank.
// Static sizes and strides of the result are materialized as constants;
// dynamic ones are recomputed from the source descriptor and the
// reassociation indices.
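// For example (a sketch of the ops this pattern matches):
//   %1 = memref.collapse_shape %0 [[0, 1], [2]]
//       : memref<2x3x4xf32> into memref<6x4xf32>
//   %2 = memref.expand_shape %1 [[0, 1], [2]]
//       : memref<6x4xf32> into memref<2x3x4xf32>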
1546 template <typename ReshapeOp>
1547 class ReassociatingReshapeOpConversion
1548     : public ConvertOpToLLVMPattern<ReshapeOp> {
1549 public:
1550   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1551   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1552 
1553   LogicalResult
1554   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1555                   ConversionPatternRewriter &rewriter) const override {
1556     MemRefType dstType = reshapeOp.getResultType();
1557     MemRefType srcType = reshapeOp.getSrcType();
1558 
1559     int64_t offset;
1560     SmallVector<int64_t, 4> strides;
1561     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1562       return rewriter.notifyMatchFailure(
1563           reshapeOp, "failed to get stride and offset exprs");
1564     }
1565 
1566     MemRefDescriptor srcDesc(adaptor.getSrc());
1567     Location loc = reshapeOp->getLoc();
1568     auto dstDesc = MemRefDescriptor::undef(
1569         rewriter, loc, this->typeConverter->convertType(dstType));
1570     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1571     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1572     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1573 
1574     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1575     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1576     Type llvmIndexType =
1577         this->typeConverter->convertType(rewriter.getIndexType());
1578     SmallVector<Value> dstShape = getDynamicOutputShape(
1579         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1580         srcStaticShape, srcDesc, dstStaticShape);
1581     for (auto &en : llvm::enumerate(dstShape))
1582       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1583 
1584     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1585       for (auto &en : llvm::enumerate(strides))
1586         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1587     } else if (srcType.getLayout().isIdentity() &&
1588                dstType.getLayout().isIdentity()) {
1589       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1590                                                    rewriter.getIndexAttr(1));
1591       Value stride = c1;
1592       for (auto dimIndex :
1593            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1594         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1595         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1596       }
1597     } else {
1598       // There could be mixed static/dynamic strides. For simplicity, we
1599       // recompute all strides if there is at least one dynamic stride.
1600       fillInDynamicStridesForMemDescriptor(
1601           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1602           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1603     }
1604     rewriter.replaceOp(reshapeOp, {dstDesc});
1605     return success();
1606   }
1607 };
1608 
1609 /// Conversion pattern that transforms a subview op into:
1610 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1611 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1612 ///      and stride.
1613 /// The subview op is replaced by the descriptor.
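/// In rough terms (a sketch of the arithmetic below, using illustrative
/// names): for mixed offsets o_i, sizes s_i and strides t_i,
///   offset   = srcOffset + sum_i(o_i * srcStride_i)
///   size_j   = s_i                  (for each non-dropped source dim i)
///   stride_j = t_i * srcStride_i    (or the statically inferred stride)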
1614 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1615   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1616 
1617   LogicalResult
1618   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1619                   ConversionPatternRewriter &rewriter) const override {
1620     auto loc = subViewOp.getLoc();
1621 
1622     auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
1623     auto sourceElementTy =
1624         typeConverter->convertType(sourceMemRefType.getElementType());
1625 
1626     auto viewMemRefType = subViewOp.getType();
1627     auto inferredType =
1628         memref::SubViewOp::inferResultType(
1629             subViewOp.getSourceType(),
1630             extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1631             extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1632             extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
1633             .cast<MemRefType>();
1634     auto targetElementTy =
1635         typeConverter->convertType(viewMemRefType.getElementType());
1636     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1637     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1638         !LLVM::isCompatibleType(sourceElementTy) ||
1639         !LLVM::isCompatibleType(targetElementTy) ||
1640         !LLVM::isCompatibleType(targetDescTy))
1641       return failure();
1642 
1643     // Extract the offset and strides from the type.
1644     int64_t offset;
1645     SmallVector<int64_t, 4> strides;
1646     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1647     if (failed(successStrides))
1648       return failure();
1649 
1650     // Create the descriptor.
1651     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1652       return failure();
1653     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1654     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1655 
1656     // Copy the buffer pointer from the old descriptor to the new one.
1657     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1658     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1659         loc,
1660         LLVM::LLVMPointerType::get(targetElementTy,
1661                                    viewMemRefType.getMemorySpaceAsInt()),
1662         extracted);
1663     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1664 
1665     // Copy the aligned pointer from the old descriptor to the new one.
1666     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1667     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1668         loc,
1669         LLVM::LLVMPointerType::get(targetElementTy,
1670                                    viewMemRefType.getMemorySpaceAsInt()),
1671         extracted);
1672     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1673 
1674     size_t inferredShapeRank = inferredType.getRank();
1675     size_t resultShapeRank = viewMemRefType.getRank();
1676 
1677     // Extract strides needed to compute offset.
1678     SmallVector<Value, 4> strideValues;
1679     strideValues.reserve(inferredShapeRank);
1680     for (unsigned i = 0; i < inferredShapeRank; ++i)
1681       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1682 
1683     // Offset.
1684     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1685     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1686       targetMemRef.setConstantOffset(rewriter, loc, offset);
1687     } else {
1688       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1689       // `inferredShapeRank` may be larger than the number of offset operands
1690       // because of trailing semantics. In this case, the offset is guaranteed
1691       // to be interpreted as 0 and we can just skip the extra dimensions.
1692       for (unsigned i = 0, e = std::min(inferredShapeRank,
1693                                         subViewOp.getMixedOffsets().size());
1694            i < e; ++i) {
1695         Value offset =
1696             // TODO: need OpFoldResult ODS adaptor to clean this up.
1697             subViewOp.isDynamicOffset(i)
1698                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1699                 : rewriter.create<LLVM::ConstantOp>(
1700                       loc, llvmIndexType,
1701                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1702         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1703         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1704       }
1705       targetMemRef.setOffset(rewriter, loc, baseOffset);
1706     }
1707 
1708     // Update sizes and strides.
1709     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1710     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1711     assert(mixedSizes.size() == mixedStrides.size() &&
1712            "expected sizes and strides of equal length");
1713     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1714     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1715          i >= 0 && j >= 0; --i) {
1716       if (unusedDims.test(i))
1717         continue;
1718 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full source dimension and the stride to be 1.
1722       Value size, stride;
1723       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1724         // If the static size is available, use it directly. This is similar to
1725         // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of arith
        // constants.
1728         int64_t staticSize =
1729             subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
1730         if (staticSize != ShapedType::kDynamicSize) {
1731           size = rewriter.create<LLVM::ConstantOp>(
1732               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1733         } else {
1734           Value pos = rewriter.create<LLVM::ConstantOp>(
1735               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1736           Value dim =
1737               rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1738           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1739               loc, llvmIndexType, dim);
1740           size = cast.getResult(0);
1741         }
1742         stride = rewriter.create<LLVM::ConstantOp>(
1743             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1744       } else {
1745         // TODO: need OpFoldResult ODS adaptor to clean this up.
1746         size =
1747             subViewOp.isDynamicSize(i)
1748                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1749                 : rewriter.create<LLVM::ConstantOp>(
1750                       loc, llvmIndexType,
1751                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1752         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1753           stride = rewriter.create<LLVM::ConstantOp>(
1754               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1755         } else {
1756           stride =
1757               subViewOp.isDynamicStride(i)
1758                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1759                   : rewriter.create<LLVM::ConstantOp>(
1760                         loc, llvmIndexType,
1761                         rewriter.getI64IntegerAttr(
1762                             subViewOp.getStaticStride(i)));
1763           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1764         }
1765       }
1766       targetMemRef.setSize(rewriter, loc, j, size);
1767       targetMemRef.setStride(rewriter, loc, j, stride);
1768       j--;
1769     }
1770 
1771     rewriter.replaceOp(subViewOp, {targetMemRef});
1772     return success();
1773   }
1774 };
1775 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride. Sizes and strides are permutations of the original values.
/// The transpose op is replaced by the descriptor.
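/// For example, a transpose with permutation (d0, d1) -> (d1, d0) swaps the
/// two size and stride entries of the descriptor while the pointers and the
/// offset are copied through unchanged.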
1783 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1784 public:
1785   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1786 
1787   LogicalResult
1788   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1789                   ConversionPatternRewriter &rewriter) const override {
1790     auto loc = transposeOp.getLoc();
1791     MemRefDescriptor viewMemRef(adaptor.getIn());
1792 
1793     // No permutation, early exit.
1794     if (transposeOp.getPermutation().isIdentity())
1795       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1796 
1797     auto targetMemRef = MemRefDescriptor::undef(
1798         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1799 
1800     // Copy the base and aligned pointers from the old descriptor to the new
1801     // one.
1802     targetMemRef.setAllocatedPtr(rewriter, loc,
1803                                  viewMemRef.allocatedPtr(rewriter, loc));
1804     targetMemRef.setAlignedPtr(rewriter, loc,
1805                                viewMemRef.alignedPtr(rewriter, loc));
1806 
1807     // Copy the offset pointer from the old descriptor to the new one.
1808     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1809 
1810     // Iterate over the dimensions and apply size/stride permutation.
1811     for (const auto &en :
1812          llvm::enumerate(transposeOp.getPermutation().getResults())) {
1813       int sourcePos = en.index();
1814       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1815       targetMemRef.setSize(rewriter, loc, targetPos,
1816                            viewMemRef.size(rewriter, loc, sourcePos));
1817       targetMemRef.setStride(rewriter, loc, targetPos,
1818                              viewMemRef.stride(rewriter, loc, sourcePos));
1819     }
1820 
1821     rewriter.replaceOp(transposeOp, {targetMemRef});
1822     return success();
1823   }
1824 };
1825 
/// Conversion pattern that transforms a view op into:
1827 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1828 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1829 ///      and stride.
1830 /// The view op is replaced by the descriptor.
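/// For example (a sketch, illustrative types only):
///   %v = memref.view %src[%byte_shift][%size]
///       : memref<2048xi8> to memref<?x128xf32>
/// produces a descriptor whose aligned pointer is advanced by %byte_shift and
/// bitcast to the f32 element pointer type, with offset 0 and contiguous
/// strides.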
1831 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1832   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1833 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand.
1836   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1837                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1838                 unsigned idx) const {
1839     assert(idx < shape.size());
1840     if (!ShapedType::isDynamic(shape[idx]))
1841       return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in the range [0, idx).
1843     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1844       return ShapedType::isDynamic(v);
1845     });
1846     return dynamicSizes[nDynamic];
1847   }
1848 
1849   // Build and return the idx^th stride, either by returning the constant stride
1850   // or by computing the dynamic stride from the current `runningStride` and
1851   // `nextSize`. The caller should keep a running stride and update it with the
1852   // result returned by this function.
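  // As an illustration: for a view of shape (%d0, 4) the strides [4, 1] are
  // both static constants; for a shape (4, %d1) the innermost stride is the
  // constant 1 and the outer stride is computed at runtime as the product of
  // the inner stride and the inner size, i.e. 1 * %d1 (names illustrative).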
1853   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1854                   ArrayRef<int64_t> strides, Value nextSize,
1855                   Value runningStride, unsigned idx) const {
1856     assert(idx < strides.size());
1857     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1858       return createIndexConstant(rewriter, loc, strides[idx]);
1859     if (nextSize)
1860       return runningStride
1861                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1862                  : nextSize;
1863     assert(!runningStride);
1864     return createIndexConstant(rewriter, loc, 1);
1865   }
1866 
1867   LogicalResult
1868   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1869                   ConversionPatternRewriter &rewriter) const override {
1870     auto loc = viewOp.getLoc();
1871 
1872     auto viewMemRefType = viewOp.getType();
1873     auto targetElementTy =
1874         typeConverter->convertType(viewMemRefType.getElementType());
1875     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1876     if (!targetDescTy || !targetElementTy ||
1877         !LLVM::isCompatibleType(targetElementTy) ||
1878         !LLVM::isCompatibleType(targetDescTy))
1879       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1880              failure();
1881 
1882     int64_t offset;
1883     SmallVector<int64_t, 4> strides;
1884     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1885     if (failed(successStrides))
1886       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1887     assert(offset == 0 && "expected offset to be 0");
1888 
1889     // Target memref must be contiguous in memory (innermost stride is 1), or
1890     // empty (special case when at least one of the memref dimensions is 0).
1891     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1892       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1893              failure();
1894 
1895     // Create the descriptor.
1896     MemRefDescriptor sourceMemRef(adaptor.getSource());
1897     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1898 
1899     // Field 1: Copy the allocated pointer, used for malloc/free.
1900     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1901     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
1902     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1903         loc,
1904         LLVM::LLVMPointerType::get(targetElementTy,
1905                                    srcMemRefType.getMemorySpaceAsInt()),
1906         allocatedPtr);
1907     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1908 
1909     // Field 2: Copy the actual aligned pointer to payload.
1910     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1911     alignedPtr = rewriter.create<LLVM::GEPOp>(
1912         loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
1913     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1914         loc,
1915         LLVM::LLVMPointerType::get(targetElementTy,
1916                                    srcMemRefType.getMemorySpaceAsInt()),
1917         alignedPtr);
1918     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1919 
1920     // Field 3: The offset in the resulting type must be 0. This is because of
1921     // the type change: an offset on srcType* may not be expressible as an
1922     // offset on dstType*.
1923     targetMemRef.setOffset(rewriter, loc,
1924                            createIndexConstant(rewriter, loc, offset));
1925 
1926     // Early exit for 0-D corner case.
1927     if (viewMemRefType.getRank() == 0)
1928       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1929 
1930     // Fields 4 and 5: Update sizes and strides.
1931     Value stride = nullptr, nextSize = nullptr;
1932     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1933       // Update size.
1934       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1935                            adaptor.getSizes(), i);
1936       targetMemRef.setSize(rewriter, loc, i, size);
1937       // Update stride.
1938       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1939       targetMemRef.setStride(rewriter, loc, i, stride);
1940       nextSize = size;
1941     }
1942 
1943     rewriter.replaceOp(viewOp, {targetMemRef});
1944     return success();
1945   }
1946 };
1947 
1948 //===----------------------------------------------------------------------===//
1949 // AtomicRMWOpLowering
1950 //===----------------------------------------------------------------------===//
1951 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
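/// For example, arith::AtomicRMWKind::addf maps to LLVM::AtomicBinOp::fadd,
/// while a kind with no direct llvm.atomicrmw equivalent (e.g. maxf) hits the
/// default case and returns llvm::None.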
1954 static Optional<LLVM::AtomicBinOp>
1955 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1956   switch (atomicOp.getKind()) {
1957   case arith::AtomicRMWKind::addf:
1958     return LLVM::AtomicBinOp::fadd;
1959   case arith::AtomicRMWKind::addi:
1960     return LLVM::AtomicBinOp::add;
1961   case arith::AtomicRMWKind::assign:
1962     return LLVM::AtomicBinOp::xchg;
1963   case arith::AtomicRMWKind::maxs:
1964     return LLVM::AtomicBinOp::max;
1965   case arith::AtomicRMWKind::maxu:
1966     return LLVM::AtomicBinOp::umax;
1967   case arith::AtomicRMWKind::mins:
1968     return LLVM::AtomicBinOp::min;
1969   case arith::AtomicRMWKind::minu:
1970     return LLVM::AtomicBinOp::umin;
1971   case arith::AtomicRMWKind::ori:
1972     return LLVM::AtomicBinOp::_or;
1973   case arith::AtomicRMWKind::andi:
1974     return LLVM::AtomicBinOp::_and;
1975   default:
1976     return llvm::None;
1977   }
1978   llvm_unreachable("Invalid AtomicRMWKind");
1979 }
1980 
1981 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1982   using Base::Base;
1983 
1984   LogicalResult
1985   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1986                   ConversionPatternRewriter &rewriter) const override {
1987     if (failed(match(atomicOp)))
1988       return failure();
1989     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1990     if (!maybeKind)
1991       return failure();
1992     auto resultType = adaptor.getValue().getType();
1993     auto memRefType = atomicOp.getMemRefType();
1994     auto dataPtr =
1995         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1996                              adaptor.getIndices(), rewriter);
1997     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1998         atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
1999         LLVM::AtomicOrdering::acq_rel);
2000     return success();
2001   }
2002 };
2003 
2004 } // namespace
2005 
2006 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
2007                                                   RewritePatternSet &patterns) {
2008   // clang-format off
2009   patterns.add<
2010       AllocaOpLowering,
2011       AllocaScopeOpLowering,
2012       AtomicRMWOpLowering,
2013       AssumeAlignmentOpLowering,
2014       DimOpLowering,
2015       GenericAtomicRMWOpLowering,
2016       GlobalMemrefOpLowering,
2017       GetGlobalMemrefOpLowering,
2018       LoadOpLowering,
2019       MemRefCastOpLowering,
2020       MemRefCopyOpLowering,
2021       MemRefReinterpretCastOpLowering,
2022       MemRefReshapeOpLowering,
2023       PrefetchOpLowering,
2024       RankOpLowering,
2025       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2026       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2027       StoreOpLowering,
2028       SubViewOpLowering,
2029       TransposeOpLowering,
2030       ViewOpLowering>(converter);
2031   // clang-format on
2032   auto allocLowering = converter.getOptions().allocLowering;
2033   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2034     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
2035   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2036     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
2037 }
2038 
2039 namespace {
2040 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2041   MemRefToLLVMPass() = default;
2042 
2043   void runOnOperation() override {
2044     Operation *op = getOperation();
2045     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2046     LowerToLLVMOptions options(&getContext(),
2047                                dataLayoutAnalysis.getAtOrAbove(op));
2048     options.allocLowering =
2049         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2050                          : LowerToLLVMOptions::AllocLowering::Malloc);
2051     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2052       options.overrideIndexBitwidth(indexBitwidth);
2053 
2054     LLVMTypeConverter typeConverter(&getContext(), options,
2055                                     &dataLayoutAnalysis);
2056     RewritePatternSet patterns(&getContext());
2057     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2058     LLVMConversionTarget target(getContext());
2059     target.addLegalOp<func::FuncOp>();
2060     if (failed(applyPartialConversion(op, target, std::move(patterns))))
2061       signalPassFailure();
2062   }
2063 };
2064 } // namespace
2065 
2066 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2067   return std::make_unique<MemRefToLLVMPass>();
2068 }
2069