1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
30   return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
31 }
32 
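/// Lowers memref.alloc to a call to `malloc` (or the generic allocation
/// function when that option is enabled), followed by the usual MemRef
/// descriptor construction done by `AllocLikeOpLLVMLowering`. As a rough
/// sketch, assuming a 64-bit index type and the default functions (the size
/// computation is folded to a constant here for brevity; the real lowering
/// materializes it with GEP/ptrtoint arithmetic):
///
///   %0 = memref.alloc() : memref<4xf32>
///
/// becomes approximately
///
///   %sz = llvm.mlir.constant(16 : index) : i64
///   %p  = llvm.call @malloc(%sz) : (i64) -> !llvm.ptr<i8>
///   %e  = llvm.bitcast %p : !llvm.ptr<i8> to !llvm.ptr<f32>
///   ... insertvalue ops filling the descriptor with %e, sizes, strides ...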
33 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
34   AllocOpLowering(LLVMTypeConverter &converter)
35       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
36                                 converter) {}
37 
38   LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
39     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
40 
41     if (useGenericFn)
42       return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
43 
44     return LLVM::lookupOrCreateMallocFn(module, getIndexType());
45   }
46 
47   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
48                                           Location loc, Value sizeBytes,
49                                           Operation *op) const override {
50     // Heap allocations.
51     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
52     MemRefType memRefType = allocOp.getType();
53 
54     Value alignment;
55     if (auto alignmentAttr = allocOp.getAlignment()) {
56       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
57     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
58       // In the case where no alignment is specified, we may want to override
59       // `malloc's` behavior. `malloc` typically aligns at the size of the
60       // biggest scalar on a target HW. For non-scalars, use the natural
61       // alignment of the LLVM type given by the LLVM DataLayout.
62       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
63     }
64 
65     if (alignment) {
66       // Adjust the allocation size to consider alignment.
67       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
68     }
69 
70     // Allocate the underlying buffer and store a pointer to it in the MemRef
71     // descriptor.
72     Type elementPtrType = this->getElementPtrType(memRefType);
73     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
74     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
75                                   getVoidPtrType());
76     Value allocatedPtr =
77         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
78 
79     Value alignedPtr = allocatedPtr;
80     if (alignment) {
81       // Compute the aligned type pointer.
82       Value allocatedInt =
83           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
84       Value alignmentInt =
85           createAligned(rewriter, loc, allocatedInt, alignment);
86       alignedPtr =
87           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
88     }
89 
90     return std::make_tuple(allocatedPtr, alignedPtr);
91   }
92 };
93 
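/// Lowers memref.alloc to a call to `aligned_alloc` (or its generic
/// counterpart) when the use-aligned-alloc conversion option is enabled. As a
/// rough sketch, with a 64-bit index type:
///
///   %0 = memref.alloc() {alignment = 32 : i64} : memref<4xf32>
///
/// becomes approximately
///
///   %al = llvm.mlir.constant(32 : index) : i64
///   %sz = ... allocation size, padded to a multiple of %al if needed ...
///   %p  = llvm.call @aligned_alloc(%al, %sz) : (i64, i64) -> !llvm.ptr<i8>
///   %e  = llvm.bitcast %p : !llvm.ptr<i8> to !llvm.ptr<f32>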
94 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
95   AlignedAllocOpLowering(LLVMTypeConverter &converter)
96       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
97                                 converter) {}
98 
99   /// Returns the memref's element size in bytes using the data layout active at
100   /// `op`.
101   // TODO: there are other places where this is used. Expose publicly?
102   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
103     const DataLayout *layout = &defaultLayout;
104     if (const DataLayoutAnalysis *analysis =
105             getTypeConverter()->getDataLayoutAnalysis()) {
106       layout = &analysis->getAbove(op);
107     }
108     Type elementType = memRefType.getElementType();
109     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
110       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
111                                                          *layout);
112     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
113       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
114           memRefElementType, *layout);
115     return layout->getTypeSize(elementType);
116   }
117 
118   /// Returns true if the memref size in bytes is known to be a multiple of
119   /// factor assuming the data layout active at `op`.
120   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
121                               Operation *op) const {
122     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
123     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
124       if (ShapedType::isDynamic(type.getDimSize(i)))
125         continue;
126       sizeDivisor = sizeDivisor * type.getDimSize(i);
127     }
128     return sizeDivisor % factor == 0;
129   }
130 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the allocation alignment to be a power of two, and
  /// the allocation size to be a multiple of that alignment.
134   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
135     if (Optional<uint64_t> alignment = allocOp.getAlignment())
136       return *alignment;
137 
    // Whenever we don't have an alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump it up to the next power of two if it isn't one already.
141     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
142     return std::max(kMinAlignedAllocAlignment,
143                     llvm::PowerOf2Ceil(eltSizeBytes));
144   }
145 
146   LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
147     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
148 
149     if (useGenericFn)
150       return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
151 
152     return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
153   }
154 
155   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
156                                           Location loc, Value sizeBytes,
157                                           Operation *op) const override {
158     // Heap allocations.
159     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
160     MemRefType memRefType = allocOp.getType();
161     int64_t alignment = getAllocationAlignment(allocOp);
162     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
163 
164     // aligned_alloc requires size to be a multiple of alignment; we will pad
165     // the size to the next multiple if necessary.
166     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
167       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
168 
169     Type elementPtrType = this->getElementPtrType(memRefType);
170     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
171     auto results =
172         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
173                        getVoidPtrType());
174     Value allocatedPtr =
175         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
176 
177     return std::make_tuple(allocatedPtr, allocatedPtr);
178   }
179 
180   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
181   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
182 
183   /// Default layout to use in absence of the corresponding analysis.
184   DataLayout defaultLayout;
185 };
186 
187 // Out of line definition, required till C++17.
188 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
189 
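/// Lowers memref.alloca to an llvm.alloca of the converted element type. As a
/// rough sketch:
///
///   %0 = memref.alloca() : memref<4xf32>
///
/// becomes approximately
///
///   %sz = ... allocation size computed by the common alloc-like lowering ...
///   %p  = llvm.alloca %sz x f32 : (i64) -> !llvm.ptr<f32>
///
/// followed by the usual descriptor construction; an optional alignment
/// attribute is forwarded to the llvm.alloca directly.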
190 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
191   AllocaOpLowering(LLVMTypeConverter &converter)
192       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
193                                 converter) {}
194 
  /// Allocates the underlying buffer with an llvm.alloca. Since the stack
  /// allocation returns a pointer to the element type directly and any
  /// requested alignment is handled by the alloca itself, the allocated and
  /// aligned pointers returned are the same value.
198   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
199                                           Location loc, Value sizeBytes,
200                                           Operation *op) const override {
201 
    // With alloca, one gets a pointer to the element type right away; no
    // post-allocation alignment or bitcast is needed for stack allocations.
204     auto allocaOp = cast<memref::AllocaOp>(op);
205     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
206 
207     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
208         loc, elementPtrType, sizeBytes,
209         allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);
210 
211     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
212   }
213 };
214 
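/// Lowers memref.alloca_scope by inlining its body into the surrounding
/// control flow and bracketing it with llvm.intr.stacksave /
/// llvm.intr.stackrestore. As a rough sketch, for a result-less scope:
///
///   memref.alloca_scope {
///     ... body ...
///   }
///
/// becomes approximately
///
///   %sp = llvm.intr.stacksave : !llvm.ptr<i8>
///   llvm.br ^body
/// ^body:
///   ... body ...
///   llvm.intr.stackrestore %sp
///   llvm.br ^cont
/// ^cont:
///   ...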
215 struct AllocaScopeOpLowering
216     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
217   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
218 
219   LogicalResult
220   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
221                   ConversionPatternRewriter &rewriter) const override {
222     OpBuilder::InsertionGuard guard(rewriter);
223     Location loc = allocaScopeOp.getLoc();
224 
225     // Split the current block before the AllocaScopeOp to create the inlining
226     // point.
227     auto *currentBlock = rewriter.getInsertionBlock();
228     auto *remainingOpsBlock =
229         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
230     Block *continueBlock;
231     if (allocaScopeOp.getNumResults() == 0) {
232       continueBlock = remainingOpsBlock;
233     } else {
234       continueBlock = rewriter.createBlock(
235           remainingOpsBlock, allocaScopeOp.getResultTypes(),
236           SmallVector<Location>(allocaScopeOp->getNumResults(),
237                                 allocaScopeOp.getLoc()));
238       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
239     }
240 
241     // Inline body region.
242     Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
243     Block *afterBody = &allocaScopeOp.getBodyRegion().back();
244     rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
245 
246     // Save stack and then branch into the body of the region.
247     rewriter.setInsertionPointToEnd(currentBlock);
248     auto stackSaveOp =
249         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
250     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
251 
252     // Replace the alloca_scope return with a branch that jumps out of the body.
253     // Stack restore before leaving the body region.
254     rewriter.setInsertionPointToEnd(afterBody);
255     auto returnOp =
256         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
257     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
258         returnOp, returnOp.getResults(), continueBlock);
259 
    // Insert the stack restore before jumping out of the body of the region.
261     rewriter.setInsertionPoint(branchOp);
262     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
263 
264     // Replace the op with values return from the body region.
265     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
266 
267     return success();
268   }
269 };
270 
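/// Lowers memref.assume_alignment to an llvm.intr.assume over an alignment
/// check of the descriptor's aligned pointer. As a rough sketch, assuming a
/// 64-bit index type:
///
///   memref.assume_alignment %m, 16 : memref<?xf32>
///
/// becomes approximately
///
///   %p    = ... aligned pointer extracted from the descriptor of %m ...
///   %pi   = llvm.ptrtoint %p : !llvm.ptr<f32> to i64
///   %mask = llvm.mlir.constant(15 : index) : i64
///   %and  = llvm.and %pi, %mask : i64
///   %zero = llvm.mlir.constant(0 : index) : i64
///   %cond = llvm.icmp "eq" %and, %zero : i64
///   "llvm.intr.assume"(%cond) : (i1) -> ()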
271 struct AssumeAlignmentOpLowering
272     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
273   using ConvertOpToLLVMPattern<
274       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
275 
276   LogicalResult
277   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
278                   ConversionPatternRewriter &rewriter) const override {
279     Value memref = adaptor.getMemref();
280     unsigned alignment = op.getAlignment();
281     auto loc = op.getLoc();
282 
283     MemRefDescriptor memRefDescriptor(memref);
284     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
285 
286     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
287     // the asserted memref.alignedPtr isn't used anywhere else, as the real
288     // users like load/store/views always re-extract memref.alignedPtr as they
289     // get lowered.
290     //
291     // This relies on LLVM's CSE optimization (potentially after SROA), since
292     // after CSE all memref.alignedPtr instances get de-duplicated into the same
293     // pointer SSA value.
294     auto intPtrType =
295         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
296     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
297     Value mask =
298         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
299     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
300     rewriter.create<LLVM::AssumeOp>(
301         loc, rewriter.create<LLVM::ICmpOp>(
302                  loc, LLVM::ICmpPredicate::eq,
303                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
304 
305     rewriter.eraseOp(op);
306     return success();
307   }
308 };
309 
310 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
311 // The memref descriptor being an SSA value, there is no need to clean it up
312 // in any way.
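//
// As a rough sketch, with the default (non-generic) free function:
//
//   memref.dealloc %m : memref<?xf32>
//
// becomes approximately
//
//   %a = llvm.extractvalue %d[0]
//       : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
//   %v = llvm.bitcast %a : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%v) : (!llvm.ptr<i8>) -> ()
//
// where %d is the converted descriptor of %m.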
313 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
314   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
315 
316   explicit DeallocOpLowering(LLVMTypeConverter &converter)
317       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
318 
319   LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
320     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
321 
322     if (useGenericFn)
323       return LLVM::lookupOrCreateGenericFreeFn(module);
324 
325     return LLVM::lookupOrCreateFreeFn(module);
326   }
327 
328   LogicalResult
329   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
330                   ConversionPatternRewriter &rewriter) const override {
331     // Insert the `free` declaration if it is not already present.
332     auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
333     MemRefDescriptor memref(adaptor.getMemref());
334     Value casted = rewriter.create<LLVM::BitcastOp>(
335         op.getLoc(), getVoidPtrType(),
336         memref.allocatedPtr(rewriter, op.getLoc()));
337     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
338         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
339     return success();
340   }
341 };
342 
343 // A `dim` is converted to a constant for static sizes and to an access to the
344 // size stored in the memref descriptor for dynamic sizes.
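//
// As a rough sketch, for a constant index into a dynamic dimension and a
// 64-bit index type:
//
//   %c1 = arith.constant 1 : index
//   %d  = memref.dim %m, %c1 : memref<4x?xf32>
//
// becomes approximately
//
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
//
// while a static dimension folds to an llvm.mlir.constant of its size.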
345 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
346   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
347 
348   LogicalResult
349   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
350                   ConversionPatternRewriter &rewriter) const override {
351     Type operandType = dimOp.getSource().getType();
352     if (operandType.isa<UnrankedMemRefType>()) {
353       rewriter.replaceOp(
354           dimOp, {extractSizeOfUnrankedMemRef(
355                      operandType, dimOp, adaptor.getOperands(), rewriter)});
356 
357       return success();
358     }
359     if (operandType.isa<MemRefType>()) {
360       rewriter.replaceOp(
361           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
362                                             adaptor.getOperands(), rewriter)});
363       return success();
364     }
365     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
366   }
367 
368 private:
369   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
370                                     OpAdaptor adaptor,
371                                     ConversionPatternRewriter &rewriter) const {
372     Location loc = dimOp.getLoc();
373 
374     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
375     auto scalarMemRefType =
376         MemRefType::get({}, unrankedMemRefType.getElementType());
377     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
378 
379     // Extract pointer to the underlying ranked descriptor and bitcast it to a
380     // memref<element_type> descriptor pointer to minimize the number of GEP
381     // operations.
382     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
383     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
384     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
385         loc,
386         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
387                                    addressSpace),
388         underlyingRankedDesc);
389 
390     // Get pointer to offset field of memref<element_type> descriptor.
391     Type indexPtrTy = LLVM::LLVMPointerType::get(
392         getTypeConverter()->getIndexType(), addressSpace);
393     Value two = rewriter.create<LLVM::ConstantOp>(
394         loc, typeConverter->convertType(rewriter.getI32Type()),
395         rewriter.getI32IntegerAttr(2));
396     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
397         loc, indexPtrTy, scalarMemRefDescPtr,
398         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
399 
    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.getIndex() + 1` as the index argument.
402     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
403         loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
404     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
405                                                  ValueRange({idxPlusOne}));
406     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
407   }
408 
409   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
410     if (Optional<int64_t> idx = dimOp.getConstantIndex())
411       return idx;
412 
413     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
414       return constantOp.getValue()
415           .cast<IntegerAttr>()
416           .getValue()
417           .getSExtValue();
418 
419     return llvm::None;
420   }
421 
422   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
423                                   OpAdaptor adaptor,
424                                   ConversionPatternRewriter &rewriter) const {
425     Location loc = dimOp.getLoc();
426 
    // Take advantage of a statically known constant index, if any.
428     MemRefType memRefType = operandType.cast<MemRefType>();
429     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
430       int64_t i = *index;
431       if (memRefType.isDynamicDim(i)) {
432         // extract dynamic size from the memref descriptor.
433         MemRefDescriptor descriptor(adaptor.getSource());
434         return descriptor.size(rewriter, loc, i);
435       }
436       // Use constant for static size.
437       int64_t dimSize = memRefType.getDimSize(i);
438       return createIndexConstant(rewriter, loc, dimSize);
439     }
440     Value index = adaptor.getIndex();
441     int64_t rank = memRefType.getRank();
442     MemRefDescriptor memrefDescriptor(adaptor.getSource());
443     return memrefDescriptor.size(rewriter, loc, index, rank);
444   }
445 };
446 
447 /// Common base for load and store operations on MemRefs. Restricts the match
448 /// to supported MemRef types. Provides functionality to emit code accessing a
449 /// specific element of the underlying data buffer.
450 template <typename Derived>
451 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
452   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
453   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
454   using Base = LoadStoreOpLowering<Derived>;
455 
456   LogicalResult match(Derived op) const override {
457     MemRefType type = op.getMemRefType();
458     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
459   }
460 };
461 
462 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
463 /// retried until it succeeds in atomically storing a new value into memory.
464 ///
465 ///      +---------------------------------+
466 ///      |   <code before the AtomicRMWOp> |
467 ///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
469 ///      +---------------------------------+
470 ///             |
471 ///  -------|   |
472 ///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
481 ///  |          |        |
482 ///  |-----------        |
483 ///                      v
484 ///      +--------------------------------+
485 ///      | end:                           |
486 ///      |   <code after the AtomicRMWOp> |
487 ///      +--------------------------------+
488 ///
489 struct GenericAtomicRMWOpLowering
490     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
491   using Base::Base;
492 
493   LogicalResult
494   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
495                   ConversionPatternRewriter &rewriter) const override {
496     auto loc = atomicOp.getLoc();
497     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
498 
499     // Split the block into initial, loop, and ending parts.
500     auto *initBlock = rewriter.getInsertionBlock();
501     auto *loopBlock = rewriter.createBlock(
502         initBlock->getParent(), std::next(Region::iterator(initBlock)),
503         valueType, loc);
504     auto *endBlock = rewriter.createBlock(
505         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
506 
507     // Operations range to be moved to `endBlock`.
508     auto opsToMoveStart = atomicOp->getIterator();
509     auto opsToMoveEnd = initBlock->back().getIterator();
510 
511     // Compute the loaded value and branch to the loop block.
512     rewriter.setInsertionPointToEnd(initBlock);
513     auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
514     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
515                                         adaptor.getIndices(), rewriter);
516     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
517     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
518 
519     // Prepare the body of the loop block.
520     rewriter.setInsertionPointToStart(loopBlock);
521 
522     // Clone the GenericAtomicRMWOp region and extract the result.
523     auto loopArgument = loopBlock->getArgument(0);
524     BlockAndValueMapping mapping;
525     mapping.map(atomicOp.getCurrentValue(), loopArgument);
526     Block &entryBlock = atomicOp.body().front();
527     for (auto &nestedOp : entryBlock.without_terminator()) {
528       Operation *clone = rewriter.clone(nestedOp, mapping);
529       mapping.map(nestedOp.getResults(), clone->getResults());
530     }
531     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
532 
533     // Prepare the epilog of the loop block.
534     // Append the cmpxchg op to the end of the loop block.
535     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
536     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
537     auto boolType = IntegerType::get(rewriter.getContext(), 1);
538     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
539                                                      {valueType, boolType});
540     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
541         loc, pairType, dataPtr, loopArgument, result, successOrdering,
542         failureOrdering);
543     // Extract the %new_loaded and %ok values from the pair.
544     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
545         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
546     Value ok = rewriter.create<LLVM::ExtractValueOp>(
547         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
548 
549     // Conditionally branch to the end or back to the loop depending on %ok.
550     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
551                                     loopBlock, newLoaded);
552 
553     rewriter.setInsertionPointToEnd(endBlock);
554     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
555                  std::next(opsToMoveEnd), rewriter);
556 
557     // The 'result' of the atomic_rmw op is the newly loaded value.
558     rewriter.replaceOp(atomicOp, {newLoaded});
559 
560     return success();
561   }
562 
563 private:
564   // Clones a segment of ops [start, end) and erases the original.
565   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
566                     Block::iterator start, Block::iterator end,
567                     ConversionPatternRewriter &rewriter) const {
568     BlockAndValueMapping mapping;
569     mapping.map(oldResult, newResult);
570     SmallVector<Operation *, 2> opsToErase;
571     for (auto it = start; it != end; ++it) {
572       rewriter.clone(*it, mapping);
573       opsToErase.push_back(&*it);
574     }
575     for (auto *it : opsToErase)
576       rewriter.eraseOp(it);
577   }
578 };
579 
580 /// Returns the LLVM type of the global variable given the memref type `type`.
581 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
582                                           LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when lowering to the
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
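  // For example, memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>>,
  // while a rank-0 memref<f32> maps to just the converted element type f32.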
588   Type elementType = typeConverter.convertType(type.getElementType());
589   Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
591   for (int64_t dim : llvm::reverse(type.getShape()))
592     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
593   return arrayTy;
594 }
595 
/// GlobalMemrefOp is lowered to an LLVM global variable.
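///
/// As a rough sketch, for an illustrative global @g:
///
///   memref.global "private" constant @g : memref<2xf32> = dense<[1.0, 2.0]>
///
/// becomes approximately
///
///   llvm.mlir.global private constant @g(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>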
597 struct GlobalMemrefOpLowering
598     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
599   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
600 
601   LogicalResult
602   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
603                   ConversionPatternRewriter &rewriter) const override {
604     MemRefType type = global.getType();
605     if (!isConvertibleAndHasIdentityMaps(type))
606       return failure();
607 
608     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
609 
610     LLVM::Linkage linkage =
611         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
612 
613     Attribute initialValue = nullptr;
614     if (!global.isExternal() && !global.isUninitialized()) {
615       auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
616       initialValue = elementsAttr;
617 
618       // For scalar memrefs, the global variable created is of the element type,
619       // so unpack the elements attribute to extract the value.
620       if (type.getRank() == 0)
621         initialValue = elementsAttr.getSplatValue<Attribute>();
622     }
623 
624     uint64_t alignment = global.getAlignment().value_or(0);
625 
626     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
627         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
628         initialValue, alignment, type.getMemorySpaceAsInt());
629     if (!global.isExternal() && global.isUninitialized()) {
630       Block *blk = new Block();
631       newGlobal.getInitializerRegion().push_back(blk);
632       rewriter.setInsertionPointToStart(blk);
633       Value undef[] = {
634           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
635       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
636     }
637     return success();
638   }
639 };
640 
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
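///
/// As a rough sketch, for an illustrative global @g:
///
///   %0 = memref.get_global @g : memref<2xf32>
///
/// becomes approximately
///
///   %addr = llvm.mlir.addressof @g : !llvm.ptr<array<2 x f32>>
///   %gep  = llvm.getelementptr %addr[%c0, %c0]
///       : (!llvm.ptr<array<2 x f32>>, i64, i64) -> !llvm.ptr<f32>
///   ... descriptor construction with %gep as the aligned pointer ...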
644 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
645   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
646       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
647                                 converter) {}
648 
  /// Buffer "allocation" for the memref.get_global op consists of getting the
  /// address of the referenced global variable.
651   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
652                                           Location loc, Value sizeBytes,
653                                           Operation *op) const override {
654     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
655     MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
656     unsigned memSpace = type.getMemorySpaceAsInt();
657 
658     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
659     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
660         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
661         getGlobalOp.getName());
662 
    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) zero indices.
665     Type elementType = typeConverter->convertType(type.getElementType());
666     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
667 
668     SmallVector<Value> operands;
669     operands.insert(operands.end(), type.getRank() + 1,
670                     createIndexConstant(rewriter, loc, 0));
671     auto gep =
672         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
673 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debugging if that ever happens.
677     auto intPtrType = getIntPtrType(memSpace);
678     Value deadBeefConst =
679         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
680     auto deadBeefPtr =
681         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
682 
    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
685     return std::make_tuple(deadBeefPtr, gep);
686   }
687 };
688 
689 // Load operation is lowered to obtaining a pointer to the indexed element
690 // and loading it.
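//
// As a rough sketch, for an identity-layout memref and a 64-bit index type:
//
//   %v = memref.load %m[%i, %j] : memref<4x8xf32>
//
// becomes approximately
//
//   %off = ... %i * 8 + %j, computed with llvm.mul / llvm.add ...
//   %ptr = llvm.getelementptr %aligned[%off]
//       : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
//   %v   = llvm.load %ptr : !llvm.ptr<f32>
//
// where %aligned is the aligned pointer extracted from the descriptor of %m.
// Stores follow the same pattern with llvm.store (see below).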
691 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
692   using Base::Base;
693 
694   LogicalResult
695   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
696                   ConversionPatternRewriter &rewriter) const override {
697     auto type = loadOp.getMemRefType();
698 
699     Value dataPtr =
700         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
701                              adaptor.getIndices(), rewriter);
702     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
703     return success();
704   }
705 };
706 
707 // Store operation is lowered to obtaining a pointer to the indexed element,
708 // and storing the given value to it.
709 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
710   using Base::Base;
711 
712   LogicalResult
713   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
714                   ConversionPatternRewriter &rewriter) const override {
715     auto type = op.getMemRefType();
716 
717     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
718                                          adaptor.getIndices(), rewriter);
719     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
720     return success();
721   }
722 };
723 
724 // The prefetch operation is lowered in a way similar to the load operation
725 // except that the llvm.prefetch operation is used for replacement.
726 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
727   using Base::Base;
728 
729   LogicalResult
730   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
731                   ConversionPatternRewriter &rewriter) const override {
732     auto type = prefetchOp.getMemRefType();
733     auto loc = prefetchOp.getLoc();
734 
735     Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
736                                          adaptor.getIndices(), rewriter);
737 
738     // Replace with llvm.prefetch.
739     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
740     auto isWrite = rewriter.create<LLVM::ConstantOp>(
741         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
742     auto localityHint = rewriter.create<LLVM::ConstantOp>(
743         loc, llvmI32Type,
744         rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
745     auto isData = rewriter.create<LLVM::ConstantOp>(
746         loc, llvmI32Type,
747         rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
748 
749     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
750                                                 localityHint, isData);
751     return success();
752   }
753 };
754 
755 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
756   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
757 
758   LogicalResult
759   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
760                   ConversionPatternRewriter &rewriter) const override {
761     Location loc = op.getLoc();
762     Type operandType = op.getMemref().getType();
763     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
764       UnrankedMemRefDescriptor desc(adaptor.getMemref());
765       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
766       return success();
767     }
768     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
769       rewriter.replaceOp(
770           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
771       return success();
772     }
773     return failure();
774   }
775 };
776 
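/// Lowers memref.cast. Ranked-to-ranked casts keep the source descriptor as
/// is; ranked-to-unranked casts wrap a stack copy of the ranked descriptor
/// into an unranked one; unranked-to-ranked casts load the ranked descriptor
/// back from the wrapped pointer. As a rough sketch of the ranked-to-unranked
/// case:
///
///   %u = memref.cast %m : memref<4xf32> to memref<*xf32>
///
/// becomes approximately
///
///   %slot = ... llvm.alloca + llvm.store of the descriptor of %m ...
///   %void = llvm.bitcast %slot : !llvm.ptr<...> to !llvm.ptr<i8>
///   %rank = llvm.mlir.constant(1 : index) : i64
///   %u    = ... !llvm.struct<(i64, ptr<i8>)> built from %rank and %void ...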
777 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
778   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
779 
780   LogicalResult match(memref::CastOp memRefCastOp) const override {
781     Type srcType = memRefCastOp.getOperand().getType();
782     Type dstType = memRefCastOp.getType();
783 
    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
789     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
790       return success(typeConverter->convertType(srcType) ==
791                      typeConverter->convertType(dstType));
792 
    // At least one of the operands is an unranked memref type.
794     assert(srcType.isa<UnrankedMemRefType>() ||
795            dstType.isa<UnrankedMemRefType>());
796 
    // Unranked-to-unranked casts are disallowed.
798     return !(srcType.isa<UnrankedMemRefType>() &&
799              dstType.isa<UnrankedMemRefType>())
800                ? success()
801                : failure();
802   }
803 
804   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
805                ConversionPatternRewriter &rewriter) const override {
806     auto srcType = memRefCastOp.getOperand().getType();
807     auto dstType = memRefCastOp.getType();
808     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
809     auto loc = memRefCastOp.getLoc();
810 
811     // For ranked/ranked case, just keep the original descriptor.
812     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
813       return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
814 
815     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
816       // Casting ranked to unranked memref type
817       // Set the rank in the destination from the memref type
818       // Allocate space on the stack and copy the src memref descriptor
819       // Set the ptr in the destination to the stack space
820       auto srcMemRefType = srcType.cast<MemRefType>();
821       int64_t rank = srcMemRefType.getRank();
822       // ptr = AllocaOp sizeof(MemRefDescriptor)
823       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
824           loc, adaptor.getSource(), rewriter);
825       // voidptr = BitCastOp srcType* to void*
826       auto voidPtr =
827           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
828               .getResult();
829       // rank = ConstantOp srcRank
830       auto rankVal = rewriter.create<LLVM::ConstantOp>(
831           loc, getIndexType(), rewriter.getIndexAttr(rank));
832       // undef = UndefOp
833       UnrankedMemRefDescriptor memRefDesc =
834           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
835       // d1 = InsertValueOp undef, rank, 0
836       memRefDesc.setRank(rewriter, loc, rankVal);
837       // d2 = InsertValueOp d1, voidptr, 1
838       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
839       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
840 
841     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
842       // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, it is undefined behavior.
845       UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
846       // ptr = ExtractValueOp src, 1
847       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
848       // castPtr = BitCastOp i8* to structTy*
849       auto castPtr =
850           rewriter
851               .create<LLVM::BitcastOp>(
852                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
853               .getResult();
854       // struct = LoadOp castPtr
855       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
856       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
857     } else {
858       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
859     }
860   }
861 };
862 
863 /// Pattern to lower a `memref.copy` to llvm.
864 ///
865 /// For memrefs with identity layouts, the copy is lowered to the llvm
866 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
867 /// to the generic `MemrefCopyFn`.
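///
/// As a rough sketch, the intrinsic path becomes approximately
///
///   %bytes = ... element count * element size, in bytes ...
///   "llvm.intr.memcpy"(%dst, %src, %bytes, %false)
///       : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
///
/// where %dst and %src point at the (offset-adjusted) aligned pointers of the
/// two descriptors. The fallback path instead promotes both operands to
/// unranked descriptors on the stack and calls the memrefCopy runtime
/// function.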
868 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
869   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
870 
871   LogicalResult
872   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
873                           ConversionPatternRewriter &rewriter) const {
874     auto loc = op.getLoc();
875     auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
876 
877     MemRefDescriptor srcDesc(adaptor.getSource());
878 
879     // Compute number of elements.
880     Value numElements = rewriter.create<LLVM::ConstantOp>(
881         loc, getIndexType(), rewriter.getIndexAttr(1));
882     for (int pos = 0; pos < srcType.getRank(); ++pos) {
883       auto size = srcDesc.size(rewriter, loc, pos);
884       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
885     }
886 
887     // Get element size.
888     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
889     // Compute total.
890     Value totalSize =
891         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
892 
893     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
894     Value srcOffset = srcDesc.offset(rewriter, loc);
895     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
896                                                 srcBasePtr, srcOffset);
897     MemRefDescriptor targetDesc(adaptor.getTarget());
898     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
899     Value targetOffset = targetDesc.offset(rewriter, loc);
900     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
901                                                    targetBasePtr, targetOffset);
902     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
903         loc, typeConverter->convertType(rewriter.getI1Type()),
904         rewriter.getBoolAttr(false));
905     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
906                                     isVolatile);
907     rewriter.eraseOp(op);
908 
909     return success();
910   }
911 
912   LogicalResult
913   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
914                              ConversionPatternRewriter &rewriter) const {
915     auto loc = op.getLoc();
916     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
917     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
918 
919     // First make sure we have an unranked memref descriptor representation.
920     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
921       auto rank = rewriter.create<LLVM::ConstantOp>(
922           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
923       auto *typeConverter = getTypeConverter();
924       auto ptr =
925           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
926       auto voidPtr =
927           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
928               .getResult();
929       auto unrankedType =
930           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
931       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
932                                             unrankedType,
933                                             ValueRange{rank, voidPtr});
934     };
935 
936     Value unrankedSource = srcType.hasRank()
937                                ? makeUnranked(adaptor.getSource(), srcType)
938                                : adaptor.getSource();
939     Value unrankedTarget = targetType.hasRank()
940                                ? makeUnranked(adaptor.getTarget(), targetType)
941                                : adaptor.getTarget();
942 
943     // Now promote the unranked descriptors to the stack.
944     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
945                                                  rewriter.getIndexAttr(1));
946     auto promote = [&](Value desc) {
947       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
948       auto allocated =
949           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
950       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
951       return allocated;
952     };
953 
954     auto sourcePtr = promote(unrankedSource);
955     auto targetPtr = promote(unrankedTarget);
956 
957     unsigned typeSize =
958         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
959     auto elemSize = rewriter.create<LLVM::ConstantOp>(
960         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
961     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
962         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
963     rewriter.create<LLVM::CallOp>(loc, copyFn,
964                                   ValueRange{elemSize, sourcePtr, targetPtr});
965     rewriter.eraseOp(op);
966 
967     return success();
968   }
969 
970   LogicalResult
971   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
972                   ConversionPatternRewriter &rewriter) const override {
973     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
974     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
975 
976     auto isContiguousMemrefType = [](BaseMemRefType type) {
977       auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
981       return memrefType &&
982              (memrefType.getLayout().isIdentity() ||
983               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
984                isStaticShapeAndContiguousRowMajor(memrefType)));
985     };
986 
987     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
988       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
989 
990     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
991   }
992 };
993 
/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
997 static void extractPointersAndOffset(Location loc,
998                                      ConversionPatternRewriter &rewriter,
999                                      LLVMTypeConverter &typeConverter,
1000                                      Value originalOperand,
1001                                      Value convertedOperand,
1002                                      Value *allocatedPtr, Value *alignedPtr,
1003                                      Value *offset = nullptr) {
1004   Type operandType = originalOperand.getType();
1005   if (operandType.isa<MemRefType>()) {
1006     MemRefDescriptor desc(convertedOperand);
1007     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1008     *alignedPtr = desc.alignedPtr(rewriter, loc);
1009     if (offset != nullptr)
1010       *offset = desc.offset(rewriter, loc);
1011     return;
1012   }
1013 
1014   unsigned memorySpace =
1015       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
1016   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
1017   Type llvmElementType = typeConverter.convertType(elementType);
1018   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
1019       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
1020 
1021   // Extract pointer to the underlying ranked memref descriptor and cast it to
1022   // ElemType**.
1023   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1024   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1025 
1026   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1027       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1028   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1029       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1030   if (offset != nullptr) {
1031     *offset = UnrankedMemRefDescriptor::offset(
1032         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1033   }
1034 }
1035 
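/// Lowers memref.reinterpret_cast by building a fresh descriptor that reuses
/// the source's allocated and aligned pointers and takes the offset, sizes,
/// and strides from the op's static attributes or dynamic operands. As a
/// rough sketch:
///
///   %0 = memref.reinterpret_cast %src to
///            offset: [0], sizes: [%n], strides: [1]
///        : memref<?xf32> to memref<?xf32>
///
/// becomes a chain of llvm.insertvalue ops populating an
/// !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
/// with the extracted pointers, the constant offset and stride, and %n.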
1036 struct MemRefReinterpretCastOpLowering
1037     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1038   using ConvertOpToLLVMPattern<
1039       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1040 
1041   LogicalResult
1042   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1043                   ConversionPatternRewriter &rewriter) const override {
1044     Type srcType = castOp.getSource().getType();
1045 
1046     Value descriptor;
1047     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1048                                                adaptor, &descriptor)))
1049       return failure();
1050     rewriter.replaceOp(castOp, {descriptor});
1051     return success();
1052   }
1053 
1054 private:
1055   LogicalResult convertSourceMemRefToDescriptor(
1056       ConversionPatternRewriter &rewriter, Type srcType,
1057       memref::ReinterpretCastOp castOp,
1058       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1059     MemRefType targetMemRefType =
1060         castOp.getResult().getType().cast<MemRefType>();
1061     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1062                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1063     if (!llvmTargetDescriptorTy)
1064       return failure();
1065 
1066     // Create descriptor.
1067     Location loc = castOp.getLoc();
1068     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1069 
1070     // Set allocated and aligned pointers.
1071     Value allocatedPtr, alignedPtr;
1072     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1073                              castOp.getSource(), adaptor.getSource(),
1074                              &allocatedPtr, &alignedPtr);
1075     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1076     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1077 
1078     // Set offset.
1079     if (castOp.isDynamicOffset(0))
1080       desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1081     else
1082       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1083 
1084     // Set sizes and strides.
1085     unsigned dynSizeId = 0;
1086     unsigned dynStrideId = 0;
1087     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1088       if (castOp.isDynamicSize(i))
1089         desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1090       else
1091         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1092 
1093       if (castOp.isDynamicStride(i))
1094         desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1095       else
1096         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1097     }
1098     *descriptor = desc;
1099     return success();
1100   }
1101 };
1102 
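/// Lowers memref.reshape. When the shape operand has a static shape, the
/// result is a ranked descriptor whose sizes are taken from the shape memref
/// (loaded at runtime for dynamic dimensions) and whose strides are running
/// products of the inner sizes; otherwise the result is an unranked descriptor
/// whose ranked payload is materialized in a stack allocation. As a rough
/// sketch of the static-shape case:
///
///   %r = memref.reshape %src(%shape)
///       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
///
/// becomes memref.load's of %shape (for the dynamic extents) plus
/// llvm.insertvalue ops that fill in the sizes and strides of the new
/// descriptor.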
1103 struct MemRefReshapeOpLowering
1104     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1105   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1106 
1107   LogicalResult
1108   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1109                   ConversionPatternRewriter &rewriter) const override {
1110     Type srcType = reshapeOp.getSource().getType();
1111 
1112     Value descriptor;
1113     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1114                                                adaptor, &descriptor)))
1115       return failure();
1116     rewriter.replaceOp(reshapeOp, {descriptor});
1117     return success();
1118   }
1119 
1120 private:
1121   LogicalResult
1122   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1123                                   Type srcType, memref::ReshapeOp reshapeOp,
1124                                   memref::ReshapeOp::Adaptor adaptor,
1125                                   Value *descriptor) const {
1126     auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
1127     if (shapeMemRefType.hasStaticShape()) {
1128       MemRefType targetMemRefType =
1129           reshapeOp.getResult().getType().cast<MemRefType>();
1130       auto llvmTargetDescriptorTy =
1131           typeConverter->convertType(targetMemRefType)
1132               .dyn_cast_or_null<LLVM::LLVMStructType>();
1133       if (!llvmTargetDescriptorTy)
1134         return failure();
1135 
1136       // Create descriptor.
1137       Location loc = reshapeOp.getLoc();
1138       auto desc =
1139           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1140 
1141       // Set allocated and aligned pointers.
1142       Value allocatedPtr, alignedPtr;
1143       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1144                                reshapeOp.getSource(), adaptor.getSource(),
1145                                &allocatedPtr, &alignedPtr);
1146       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1147       desc.setAlignedPtr(rewriter, loc, alignedPtr);
1148 
1149       // Extract the offset and strides from the type.
1150       int64_t offset;
1151       SmallVector<int64_t> strides;
1152       if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1153         return rewriter.notifyMatchFailure(
1154             reshapeOp, "failed to get stride and offset exprs");
1155 
1156       if (!isStaticStrideOrOffset(offset))
1157         return rewriter.notifyMatchFailure(reshapeOp,
1158                                            "dynamic offset is unsupported");
1159 
1160       desc.setConstantOffset(rewriter, loc, offset);
1161 
1162       assert(targetMemRefType.getLayout().isIdentity() &&
1163              "Identity layout map is a precondition of a valid reshape op");
1164 
1165       Value stride = nullptr;
1166       int64_t targetRank = targetMemRefType.getRank();
1167       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1168         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant; dynamic strides are instead computed below as the
          // running product of the inner dimension sizes.
1171           stride = createIndexConstant(rewriter, loc, strides[i]);
1172         } else if (!stride) {
1173           // `stride` is null only in the first iteration of the loop.  However,
1174           // since the target memref has an identity layout, we can safely set
1175           // the innermost stride to 1.
1176           stride = createIndexConstant(rewriter, loc, 1);
1177         }
1178 
1179         Value dimSize;
1180         int64_t size = targetMemRefType.getDimSize(i);
1181         // If the size of this dimension is dynamic, then load it at runtime
1182         // from the shape operand.
1183         if (!ShapedType::isDynamic(size)) {
1184           dimSize = createIndexConstant(rewriter, loc, size);
1185         } else {
1186           Value shapeOp = reshapeOp.getShape();
1187           Value index = createIndexConstant(rewriter, loc, i);
1188           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1189           Type indexType = getIndexType();
1190           if (dimSize.getType() != indexType)
1191             dimSize = typeConverter->materializeTargetConversion(
1192                 rewriter, loc, indexType, dimSize);
1193           assert(dimSize && "Invalid memref element type");
1194         }
1195 
1196         desc.setSize(rewriter, loc, i, dimSize);
1197         desc.setStride(rewriter, loc, i, stride);
1198 
1199         // Prepare the stride value for the next dimension.
1200         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1201       }
1202 
1203       *descriptor = desc;
1204       return success();
1205     }
1206 
    // The result is unranked: the shape operand is a rank-1 memref whose
    // length, and thus the result rank, is only known at runtime.
1208     Location loc = reshapeOp.getLoc();
1209     MemRefDescriptor shapeDesc(adaptor.getShape());
1210     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1211 
1212     // Extract address space and element type.
1213     auto targetType =
1214         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1215     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1216     Type elementType = targetType.getElementType();
1217 
    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
1220     auto targetDesc = UnrankedMemRefDescriptor::undef(
1221         rewriter, loc, typeConverter->convertType(targetType));
1222     targetDesc.setRank(rewriter, loc, resultRank);
1223     SmallVector<Value, 4> sizes;
1224     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1225                                            targetDesc, sizes);
1226     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1227         loc, getVoidPtrType(), sizes.front(), llvm::None);
1228     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1229 
1230     // Extract pointers and offset from the source memref.
1231     Value allocatedPtr, alignedPtr, offset;
1232     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1233                              reshapeOp.getSource(), adaptor.getSource(),
1234                              &allocatedPtr, &alignedPtr, &offset);
1235 
1236     // Set pointers and offset.
1237     Type llvmElementType = typeConverter->convertType(elementType);
1238     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1239         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1240     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1241                                               elementPtrPtrType, allocatedPtr);
1242     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1243                                             underlyingDescPtr,
1244                                             elementPtrPtrType, alignedPtr);
1245     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1246                                         underlyingDescPtr, elementPtrPtrType,
1247                                         offset);
1248 
    // Use the offset pointer as the base for further addressing. Copy over the
    // new shape and compute the strides. For this, we create a loop running
    // from rank - 1 down to 0.
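    // The emitted control flow is roughly (descriptor writes elided):
    //
    //   init:           br cond(%rank - 1, 1)
    //   cond(%i, %st):  %p = icmp sge %i, 0
    //                   cond_br %p, body, remainder
    //   body:           %size = load %shape[%i]
    //                   sizes[%i] = %size; strides[%i] = %st
    //                   br cond(%i - 1, %st * %size)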
1251     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1252         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1253         elementPtrPtrType);
1254     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1255         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1256     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1257     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1258     Value resultRankMinusOne =
1259         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1260 
1261     Block *initBlock = rewriter.getInsertionBlock();
1262     Type indexType = getTypeConverter()->getIndexType();
1263     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1264 
1265     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1266                                             {indexType, indexType}, {loc, loc});
1267 
1268     // Move the remaining initBlock ops to condBlock.
1269     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1270     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1271 
1272     rewriter.setInsertionPointToEnd(initBlock);
1273     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1274                                 condBlock);
1275     rewriter.setInsertionPointToStart(condBlock);
1276     Value indexArg = condBlock->getArgument(0);
1277     Value strideArg = condBlock->getArgument(1);
1278 
1279     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1280     Value pred = rewriter.create<LLVM::ICmpOp>(
1281         loc, IntegerType::get(rewriter.getContext(), 1),
1282         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1283 
1284     Block *bodyBlock =
1285         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1286     rewriter.setInsertionPointToStart(bodyBlock);
1287 
1288     // Copy size from shape to descriptor.
1289     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1290     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1291         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1292     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1293     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1294                                       targetSizesBase, indexArg, size);
1295 
1296     // Write stride value and compute next one.
1297     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1298                                         targetStridesBase, indexArg, strideArg);
1299     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1300 
1301     // Decrement loop counter and branch back.
1302     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1303     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1304                                 condBlock);
1305 
1306     Block *remainder =
1307         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1308 
1309     // Hook up the cond exit to the remainder.
1310     rewriter.setInsertionPointToEnd(condBlock);
1311     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1312                                     llvm::None);
1313 
1314     // Reset position to beginning of new remainder block.
1315     rewriter.setInsertionPointToStart(remainder);
1316 
1317     *descriptor = targetDesc;
1318     return success();
1319   }
1320 };
1321 
1322 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1323 /// `Value`s.
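/// For illustration: an attribute entry such as an index attribute `4` is
/// materialized as an `llvm.mlir.constant` of `llvmIndexType`, while a Value
/// entry is returned unchanged.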
1324 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1325                                       Type &llvmIndexType,
1326                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1327   return llvm::to_vector<4>(
1328       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1329         if (auto attr = value.dyn_cast<Attribute>())
1330           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1331         return value.get<Value>();
1332       }));
1333 }
1334 
/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
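/// For example, the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.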
1338 static DenseMap<int64_t, int64_t>
1339 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1340   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1341   for (auto &en : enumerate(reassociation)) {
1342     for (auto dim : en.value())
1343       expandedDimToCollapsedDim[dim] = en.index();
1344   }
1345   return expandedDimToCollapsedDim;
1346 }
1347 
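/// Compute the size of a given dimension of the expanded (higher-rank) type.
/// For illustration: expanding memref<?xf32> into memref<?x4xf32> with
/// reassociation [[0, 1]] makes output dim 0 dynamic; its size is computed at
/// runtime as inDesc.size(0) / 4, i.e. the input size divided by the product
/// of the other (static) output sizes in the group.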
1348 static OpFoldResult
1349 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1350                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1351                          MemRefDescriptor &inDesc,
1352                          ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
1354                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1355   int64_t outDimSize = outStaticShape[outDimIndex];
1356   if (!ShapedType::isDynamic(outDimSize))
1357     return b.getIndexAttr(outDimSize);
1358 
  // Compute the product of all the output dimension sizes in this
  // reassociation group except the current one.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
1365       continue;
1366     int64_t otherDimSize = outStaticShape[otherDimIndex];
1367     assert(!ShapedType::isDynamic(otherDimSize) &&
1368            "single dimension cannot be expanded into multiple dynamic "
1369            "dimensions");
1370     otherDimSizesMul *= otherDimSize;
1371   }
1372 
  // outDimSize = inDimSize / otherDimSizesMul
1374   int64_t inDimSize = inStaticShape[inDimIndex];
1375   Value inDimSizeDynamic =
1376       ShapedType::isDynamic(inDimSize)
1377           ? inDesc.size(b, loc, inDimIndex)
1378           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1379                                        b.getIndexAttr(inDimSize));
1380   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1381       loc, inDimSizeDynamic,
1382       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1383                                  b.getIndexAttr(otherDimSizesMul)));
1384   return outDimSizeDynamic;
1385 }
1386 
1387 static OpFoldResult getCollapsedOutputDimSize(
1388     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1389     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
1391   if (!ShapedType::isDynamic(outDimSize))
1392     return b.getIndexAttr(outDimSize);
1393 
1394   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1395   Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
1397     int64_t inDimSize = inStaticShape[inDimIndex];
1398     Value inDimSizeDynamic =
1399         ShapedType::isDynamic(inDimSize)
1400             ? inDesc.size(b, loc, inDimIndex)
1401             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1402                                          b.getIndexAttr(inDimSize));
1403     outDimSizeDynamic =
1404         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1405   }
1406   return outDimSizeDynamic;
1407 }
1408 
1409 static SmallVector<OpFoldResult, 4>
1410 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1411                         ArrayRef<ReassociationIndices> reassociation,
1412                         ArrayRef<int64_t> inStaticShape,
1413                         MemRefDescriptor &inDesc,
1414                         ArrayRef<int64_t> outStaticShape) {
1415   return llvm::to_vector<4>(llvm::map_range(
1416       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1417         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1418                                          outStaticShape[outDimIndex],
1419                                          inStaticShape, inDesc, reassociation);
1420       }));
1421 }
1422 
1423 static SmallVector<OpFoldResult, 4>
1424 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1425                        ArrayRef<ReassociationIndices> reassociation,
1426                        ArrayRef<int64_t> inStaticShape,
1427                        MemRefDescriptor &inDesc,
1428                        ArrayRef<int64_t> outStaticShape) {
1429   DenseMap<int64_t, int64_t> outDimToInDimMap =
1430       getExpandedDimToCollapsedDimMap(reassociation);
1431   return llvm::to_vector<4>(llvm::map_range(
1432       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1433         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1434                                         outStaticShape, inDesc, inStaticShape,
1435                                         reassociation, outDimToInDimMap);
1436       }));
1437 }
1438 
1439 static SmallVector<Value>
1440 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1441                       ArrayRef<ReassociationIndices> reassociation,
1442                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1443                       ArrayRef<int64_t> outStaticShape) {
1444   return outStaticShape.size() < inStaticShape.size()
1445              ? getAsValues(b, loc, llvmIndexType,
1446                            getCollapsedOutputShape(b, loc, llvmIndexType,
1447                                                    reassociation, inStaticShape,
1448                                                    inDesc, outStaticShape))
1449              : getAsValues(b, loc, llvmIndexType,
1450                            getExpandedOutputShape(b, loc, llvmIndexType,
1451                                                   reassociation, inStaticShape,
1452                                                   inDesc, outStaticShape));
1453 }
1454 
1455 static void fillInStridesForExpandedMemDescriptor(
1456     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1457     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1458   // See comments for computeExpandedLayoutMap for details on how the strides
1459   // are calculated.
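  // For illustration: a source dimension with stride `s` expanded into two
  // result dimensions of sizes (a, b) gets strides (b * s, s); each
  // reassociation group is walked innermost-first, multiplying the running
  // stride by the result sizes.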
1460   for (auto &en : llvm::enumerate(reassociation)) {
1461     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1462     for (auto dstIndex : llvm::reverse(en.value())) {
1463       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1464       Value size = dstDesc.size(b, loc, dstIndex);
1465       currentStrideToExpand =
1466           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1467     }
1468   }
1469 }
1470 
1471 static void fillInStridesForCollapsedMemDescriptor(
1472     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1473     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1474     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1475   // See comments for computeCollapsedLayoutMap for details on how the strides
1476   // are calculated.
1477   auto srcShape = srcType.getShape();
1478   for (auto &en : llvm::enumerate(reassociation)) {
1479     rewriter.setInsertionPoint(op);
1480     auto dstIndex = en.index();
1481     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1482     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1483       ref = ref.drop_back();
1484     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1485       dstDesc.setStride(rewriter, loc, dstIndex,
1486                         srcDesc.stride(rewriter, loc, ref.back()));
1487     } else {
1488       // Iterate over the source strides in reverse order. Skip over the
1489       // dimensions whose size is 1.
1490       // TODO: we should take the minimum stride in the reassociation group
1491       // instead of just the first where the dimension is not 1.
1492       //
1493       // +------------------------------------------------------+
1494       // | curEntry:                                            |
1495       // |   %srcStride = strides[srcIndex]                     |
1496       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1497       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1498       // +-------------------------+----------------------------+  |
1499       //                           |                               |
1500       //                           v                               |
1501       //            +-----------------------------+                |
1502       //            | nextEntry:                  |                |
1503       //            |   ...                       +---+            |
1504       //            +--------------+--------------+   |            |
1505       //                           |                  |            |
1506       //                           v                  |            |
1507       //            +-----------------------------+   |            |
1508       //            | nextEntry:                  |   |            |
1509       //            |   ...                       |   |            |
1510       //            +--------------+--------------+   |   +--------+
1511       //                           |                  |   |
1512       //                           v                  v   v
1513       //   +--------------------------------------------------+
1514       //   | continue(%newStride):                            |
1515       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1516       //   +--------------------------------------------------+
1517       OpBuilder::InsertionGuard guard(rewriter);
1518       Block *initBlock = rewriter.getInsertionBlock();
1519       Block *continueBlock =
1520           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1521       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1522       rewriter.setInsertionPointToStart(continueBlock);
1523       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1524 
1525       Block *curEntryBlock = initBlock;
1526       Block *nextEntryBlock;
1527       for (auto srcIndex : llvm::reverse(ref)) {
1528         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1529           continue;
1530         rewriter.setInsertionPointToEnd(curEntryBlock);
1531         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1532         if (srcIndex == ref.front()) {
1533           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1534           break;
1535         }
        // Use the descriptor's index type so that the comparison below is
        // between values of the same LLVM integer type.
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, srcDesc.getIndexType(),
            rewriter.getIntegerAttr(srcDesc.getIndexType(), 1));
1539         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1540             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1541             one);
1542         {
1543           OpBuilder::InsertionGuard guard(rewriter);
1544           nextEntryBlock = rewriter.createBlock(
1545               initBlock->getParent(), Region::iterator(continueBlock), {});
1546         }
1547         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1548                                         srcStride, nextEntryBlock, llvm::None);
1549         curEntryBlock = nextEntryBlock;
1550       }
1551     }
1552   }
1553 }
1554 
1555 static void fillInDynamicStridesForMemDescriptor(
1556     ConversionPatternRewriter &b, Location loc, Operation *op,
1557     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1558     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1559     ArrayRef<ReassociationIndices> reassociation) {
1560   if (srcType.getRank() > dstType.getRank())
1561     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1562                                            srcDesc, dstDesc, reassociation);
1563   else
1564     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1565                                           reassociation);
1566 }
1567 
/// ReshapeOp creates a new view descriptor of the proper rank. Sizes are
/// taken from the static result shape or recomputed from the source
/// descriptor; strides are either set as constants when fully static or
/// recomputed from the reassociation when at least one of them is dynamic.
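///
/// For illustration, a fully static collapse such as
///   %1 = memref.collapse_shape %0 [[0, 1], [2]]
///       : memref<4x5x6xf32> into memref<20x6xf32>
/// copies the pointers and offset from the source descriptor and sets the
/// sizes {20, 6} and strides {6, 1} as constants in the result descriptor.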
1571 template <typename ReshapeOp>
1572 class ReassociatingReshapeOpConversion
1573     : public ConvertOpToLLVMPattern<ReshapeOp> {
1574 public:
1575   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1576   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1577 
1578   LogicalResult
1579   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1580                   ConversionPatternRewriter &rewriter) const override {
1581     MemRefType dstType = reshapeOp.getResultType();
1582     MemRefType srcType = reshapeOp.getSrcType();
1583 
1584     int64_t offset;
1585     SmallVector<int64_t, 4> strides;
1586     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1587       return rewriter.notifyMatchFailure(
1588           reshapeOp, "failed to get stride and offset exprs");
1589     }
1590 
1591     MemRefDescriptor srcDesc(adaptor.getSrc());
1592     Location loc = reshapeOp->getLoc();
1593     auto dstDesc = MemRefDescriptor::undef(
1594         rewriter, loc, this->typeConverter->convertType(dstType));
1595     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1596     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1597     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1598 
1599     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1600     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1601     Type llvmIndexType =
1602         this->typeConverter->convertType(rewriter.getIndexType());
1603     SmallVector<Value> dstShape = getDynamicOutputShape(
1604         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1605         srcStaticShape, srcDesc, dstStaticShape);
1606     for (auto &en : llvm::enumerate(dstShape))
1607       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1608 
1609     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1610       for (auto &en : llvm::enumerate(strides))
1611         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1612     } else if (srcType.getLayout().isIdentity() &&
1613                dstType.getLayout().isIdentity()) {
1614       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1615                                                    rewriter.getIndexAttr(1));
1616       Value stride = c1;
1617       for (auto dimIndex :
1618            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1619         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1620         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1621       }
1622     } else {
1623       // There could be mixed static/dynamic strides. For simplicity, we
1624       // recompute all strides if there is at least one dynamic stride.
1625       fillInDynamicStridesForMemDescriptor(
1626           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1627           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1628     }
1629     rewriter.replaceOp(reshapeOp, {dstDesc});
1630     return success();
1631   }
1632 };
1633 
1634 /// Conversion pattern that transforms a subview op into:
1635 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1636 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1637 ///      and stride.
1638 /// The subview op is replaced by the descriptor.
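///
/// For illustration (result layout written out as an affine map):
///   %1 = memref.subview %0[2, 0] [2, 2] [1, 1]
///       : memref<4x4xf32>
///         to memref<2x2xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
/// reuses the source pointers and produces a descriptor with offset 8,
/// sizes {2, 2}, and strides {4, 1}.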
1639 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1640   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1641 
1642   LogicalResult
1643   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1644                   ConversionPatternRewriter &rewriter) const override {
1645     auto loc = subViewOp.getLoc();
1646 
1647     auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
1648     auto sourceElementTy =
1649         typeConverter->convertType(sourceMemRefType.getElementType());
1650 
1651     auto viewMemRefType = subViewOp.getType();
1652     auto inferredType =
1653         memref::SubViewOp::inferResultType(
1654             subViewOp.getSourceType(),
1655             extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1656             extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1657             extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
1658             .cast<MemRefType>();
1659     auto targetElementTy =
1660         typeConverter->convertType(viewMemRefType.getElementType());
1661     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1662     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1663         !LLVM::isCompatibleType(sourceElementTy) ||
1664         !LLVM::isCompatibleType(targetElementTy) ||
1665         !LLVM::isCompatibleType(targetDescTy))
1666       return failure();
1667 
1668     // Extract the offset and strides from the type.
1669     int64_t offset;
1670     SmallVector<int64_t, 4> strides;
1671     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1672     if (failed(successStrides))
1673       return failure();
1674 
1675     // Create the descriptor.
1676     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1677       return failure();
1678     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1679     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1680 
1681     // Copy the buffer pointer from the old descriptor to the new one.
1682     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1683     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1684         loc,
1685         LLVM::LLVMPointerType::get(targetElementTy,
1686                                    viewMemRefType.getMemorySpaceAsInt()),
1687         extracted);
1688     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1689 
1690     // Copy the aligned pointer from the old descriptor to the new one.
1691     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1692     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1693         loc,
1694         LLVM::LLVMPointerType::get(targetElementTy,
1695                                    viewMemRefType.getMemorySpaceAsInt()),
1696         extracted);
1697     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1698 
1699     size_t inferredShapeRank = inferredType.getRank();
1700     size_t resultShapeRank = viewMemRefType.getRank();
1701 
1702     // Extract strides needed to compute offset.
1703     SmallVector<Value, 4> strideValues;
1704     strideValues.reserve(inferredShapeRank);
1705     for (unsigned i = 0; i < inferredShapeRank; ++i)
1706       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1707 
1708     // Offset.
1709     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1710     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1711       targetMemRef.setConstantOffset(rewriter, loc, offset);
1712     } else {
1713       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1714       // `inferredShapeRank` may be larger than the number of offset operands
1715       // because of trailing semantics. In this case, the offset is guaranteed
1716       // to be interpreted as 0 and we can just skip the extra dimensions.
1717       for (unsigned i = 0, e = std::min(inferredShapeRank,
1718                                         subViewOp.getMixedOffsets().size());
1719            i < e; ++i) {
1720         Value offset =
1721             // TODO: need OpFoldResult ODS adaptor to clean this up.
1722             subViewOp.isDynamicOffset(i)
1723                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1724                 : rewriter.create<LLVM::ConstantOp>(
1725                       loc, llvmIndexType,
1726                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1727         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1728         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1729       }
1730       targetMemRef.setOffset(rewriter, loc, baseOffset);
1731     }
1732 
1733     // Update sizes and strides.
1734     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1735     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1736     assert(mixedSizes.size() == mixedStrides.size() &&
1737            "expected sizes and strides of equal length");
1738     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1739     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1740          i >= 0 && j >= 0; --i) {
1741       if (unusedDims.test(i))
1742         continue;
1743 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In that case, the size is guaranteed to
      // be the full source dimension size and the stride to be 1.
1747       Value size, stride;
1748       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1749         // If the static size is available, use it directly. This is similar to
1750         // the folding of dim(constant-op) but removes the need for dim to be
1751         // aware of LLVM constants and for this pass to be aware of std
1752         // constants.
1753         int64_t staticSize =
1754             subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
1755         if (staticSize != ShapedType::kDynamicSize) {
1756           size = rewriter.create<LLVM::ConstantOp>(
1757               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1758         } else {
1759           Value pos = rewriter.create<LLVM::ConstantOp>(
1760               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1761           Value dim =
1762               rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1763           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1764               loc, llvmIndexType, dim);
1765           size = cast.getResult(0);
1766         }
1767         stride = rewriter.create<LLVM::ConstantOp>(
1768             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1769       } else {
1770         // TODO: need OpFoldResult ODS adaptor to clean this up.
1771         size =
1772             subViewOp.isDynamicSize(i)
1773                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1774                 : rewriter.create<LLVM::ConstantOp>(
1775                       loc, llvmIndexType,
1776                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1777         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1778           stride = rewriter.create<LLVM::ConstantOp>(
1779               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1780         } else {
1781           stride =
1782               subViewOp.isDynamicStride(i)
1783                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1784                   : rewriter.create<LLVM::ConstantOp>(
1785                         loc, llvmIndexType,
1786                         rewriter.getI64IntegerAttr(
1787                             subViewOp.getStaticStride(i)));
1788           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1789         }
1790       }
1791       targetMemRef.setSize(rewriter, loc, j, size);
1792       targetMemRef.setStride(rewriter, loc, j, stride);
1793       j--;
1794     }
1795 
1796     rewriter.replaceOp(subViewOp, {targetMemRef});
1797     return success();
1798   }
1799 };
1800 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
/// The transpose op is replaced by the descriptor.
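///
/// For illustration:
///   %1 = memref.transpose %0 (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, #map>
/// keeps the pointers and offset unchanged and swaps the size/stride pairs of
/// the two dimensions in the result descriptor.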
1808 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1809 public:
1810   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1811 
1812   LogicalResult
1813   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1814                   ConversionPatternRewriter &rewriter) const override {
1815     auto loc = transposeOp.getLoc();
1816     MemRefDescriptor viewMemRef(adaptor.getIn());
1817 
1818     // No permutation, early exit.
1819     if (transposeOp.getPermutation().isIdentity())
1820       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1821 
1822     auto targetMemRef = MemRefDescriptor::undef(
1823         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1824 
1825     // Copy the base and aligned pointers from the old descriptor to the new
1826     // one.
1827     targetMemRef.setAllocatedPtr(rewriter, loc,
1828                                  viewMemRef.allocatedPtr(rewriter, loc));
1829     targetMemRef.setAlignedPtr(rewriter, loc,
1830                                viewMemRef.alignedPtr(rewriter, loc));
1831 
1832     // Copy the offset pointer from the old descriptor to the new one.
1833     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1834 
1835     // Iterate over the dimensions and apply size/stride permutation.
1836     for (const auto &en :
1837          llvm::enumerate(transposeOp.getPermutation().getResults())) {
1838       int sourcePos = en.index();
1839       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1840       targetMemRef.setSize(rewriter, loc, targetPos,
1841                            viewMemRef.size(rewriter, loc, sourcePos));
1842       targetMemRef.setStride(rewriter, loc, targetPos,
1843                              viewMemRef.stride(rewriter, loc, sourcePos));
1844     }
1845 
1846     rewriter.replaceOp(transposeOp, {targetMemRef});
1847     return success();
1848   }
1849 };
1850 
1851 /// Conversion pattern that transforms an op into:
1852 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1853 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1854 ///      and stride.
1855 /// The view op is replaced by the descriptor.
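///
/// For illustration:
///   %1 = memref.view %0[%shift][] : memref<2048xi8> to memref<64x4xf32>
/// reuses the source allocated pointer, advances the aligned pointer by
/// %shift bytes, and fills in the static sizes {64, 4} and strides {4, 1}.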
1856 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1857   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1858 
  // Build and return the value for the idx^th shape dimension, either by
  // materializing the constant shape dimension or by selecting the proper
  // dynamic size operand, found by counting the dynamic dimensions that
  // precede `idx`.
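  // For illustration: with shape [4, ?, 8, ?] and idx == 3, shape[3] is
  // dynamic and exactly one dynamic dimension precedes it, so dynamicSizes[1]
  // is returned.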
1861   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1862                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1863                 unsigned idx) const {
1864     assert(idx < shape.size());
1865     if (!ShapedType::isDynamic(shape[idx]))
1866       return createIndexConstant(rewriter, loc, shape[idx]);
1867     // Count the number of dynamic dims in range [0, idx]
1868     unsigned nDynamic =
1869         llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1870     return dynamicSizes[nDynamic];
1871   }
1872 
1873   // Build and return the idx^th stride, either by returning the constant stride
1874   // or by computing the dynamic stride from the current `runningStride` and
1875   // `nextSize`. The caller should keep a running stride and update it with the
1876   // result returned by this function.
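  // For illustration: with strides [?, ?, 1] the caller visits idx = 2, 1, 0
  // and obtains 1, then 1 * size[2], then (1 * size[2]) * size[1], i.e. a
  // row-major stride recomputed from the sizes.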
1877   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1878                   ArrayRef<int64_t> strides, Value nextSize,
1879                   Value runningStride, unsigned idx) const {
1880     assert(idx < strides.size());
1881     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1882       return createIndexConstant(rewriter, loc, strides[idx]);
1883     if (nextSize)
1884       return runningStride
1885                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1886                  : nextSize;
1887     assert(!runningStride);
1888     return createIndexConstant(rewriter, loc, 1);
1889   }
1890 
1891   LogicalResult
1892   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1893                   ConversionPatternRewriter &rewriter) const override {
1894     auto loc = viewOp.getLoc();
1895 
1896     auto viewMemRefType = viewOp.getType();
1897     auto targetElementTy =
1898         typeConverter->convertType(viewMemRefType.getElementType());
1899     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1900     if (!targetDescTy || !targetElementTy ||
1901         !LLVM::isCompatibleType(targetElementTy) ||
1902         !LLVM::isCompatibleType(targetDescTy))
1903       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1904              failure();
1905 
1906     int64_t offset;
1907     SmallVector<int64_t, 4> strides;
1908     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1909     if (failed(successStrides))
1910       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1911     assert(offset == 0 && "expected offset to be 0");
1912 
1913     // Target memref must be contiguous in memory (innermost stride is 1), or
1914     // empty (special case when at least one of the memref dimensions is 0).
1915     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1916       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1917              failure();
1918 
1919     // Create the descriptor.
1920     MemRefDescriptor sourceMemRef(adaptor.getSource());
1921     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1922 
1923     // Field 1: Copy the allocated pointer, used for malloc/free.
1924     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1925     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
1926     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1927         loc,
1928         LLVM::LLVMPointerType::get(targetElementTy,
1929                                    srcMemRefType.getMemorySpaceAsInt()),
1930         allocatedPtr);
1931     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1932 
1933     // Field 2: Copy the actual aligned pointer to payload.
1934     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1935     alignedPtr = rewriter.create<LLVM::GEPOp>(
1936         loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
1937     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1938         loc,
1939         LLVM::LLVMPointerType::get(targetElementTy,
1940                                    srcMemRefType.getMemorySpaceAsInt()),
1941         alignedPtr);
1942     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1943 
1944     // Field 3: The offset in the resulting type must be 0. This is because of
1945     // the type change: an offset on srcType* may not be expressible as an
1946     // offset on dstType*.
1947     targetMemRef.setOffset(rewriter, loc,
1948                            createIndexConstant(rewriter, loc, offset));
1949 
1950     // Early exit for 0-D corner case.
1951     if (viewMemRefType.getRank() == 0)
1952       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1953 
1954     // Fields 4 and 5: Update sizes and strides.
1955     Value stride = nullptr, nextSize = nullptr;
1956     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1957       // Update size.
1958       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1959                            adaptor.getSizes(), i);
1960       targetMemRef.setSize(rewriter, loc, i, size);
1961       // Update stride.
1962       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1963       targetMemRef.setStride(rewriter, loc, i, stride);
1964       nextSize = size;
1965     }
1966 
1967     rewriter.replaceOp(viewOp, {targetMemRef});
1968     return success();
1969   }
1970 };
1971 
1972 //===----------------------------------------------------------------------===//
1973 // AtomicRMWOpLowering
1974 //===----------------------------------------------------------------------===//
1975 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
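/// For illustration:
///   %x = memref.atomic_rmw addf %val, %buf[%i]
///       : (f32, memref<10xf32>) -> f32
/// maps to LLVM::AtomicBinOp::fadd, while kinds without a direct llvm.atomicrmw
/// equivalent return llvm::None and are handled by the cmpxchg-based lowering
/// instead.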
1978 static Optional<LLVM::AtomicBinOp>
1979 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1980   switch (atomicOp.getKind()) {
1981   case arith::AtomicRMWKind::addf:
1982     return LLVM::AtomicBinOp::fadd;
1983   case arith::AtomicRMWKind::addi:
1984     return LLVM::AtomicBinOp::add;
1985   case arith::AtomicRMWKind::assign:
1986     return LLVM::AtomicBinOp::xchg;
1987   case arith::AtomicRMWKind::maxs:
1988     return LLVM::AtomicBinOp::max;
1989   case arith::AtomicRMWKind::maxu:
1990     return LLVM::AtomicBinOp::umax;
1991   case arith::AtomicRMWKind::mins:
1992     return LLVM::AtomicBinOp::min;
1993   case arith::AtomicRMWKind::minu:
1994     return LLVM::AtomicBinOp::umin;
1995   case arith::AtomicRMWKind::ori:
1996     return LLVM::AtomicBinOp::_or;
1997   case arith::AtomicRMWKind::andi:
1998     return LLVM::AtomicBinOp::_and;
1999   default:
2000     return llvm::None;
2001   }
2002   llvm_unreachable("Invalid AtomicRMWKind");
2003 }
2004 
2005 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
2006   using Base::Base;
2007 
2008   LogicalResult
2009   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
2010                   ConversionPatternRewriter &rewriter) const override {
2011     if (failed(match(atomicOp)))
2012       return failure();
2013     auto maybeKind = matchSimpleAtomicOp(atomicOp);
2014     if (!maybeKind)
2015       return failure();
2016     auto resultType = adaptor.getValue().getType();
2017     auto memRefType = atomicOp.getMemRefType();
2018     auto dataPtr =
2019         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2020                              adaptor.getIndices(), rewriter);
2021     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
2022         atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
2023         LLVM::AtomicOrdering::acq_rel);
2024     return success();
2025   }
2026 };
2027 
2028 } // namespace
2029 
2030 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
2031                                                   RewritePatternSet &patterns) {
2032   // clang-format off
2033   patterns.add<
2034       AllocaOpLowering,
2035       AllocaScopeOpLowering,
2036       AtomicRMWOpLowering,
2037       AssumeAlignmentOpLowering,
2038       DimOpLowering,
2039       GenericAtomicRMWOpLowering,
2040       GlobalMemrefOpLowering,
2041       GetGlobalMemrefOpLowering,
2042       LoadOpLowering,
2043       MemRefCastOpLowering,
2044       MemRefCopyOpLowering,
2045       MemRefReinterpretCastOpLowering,
2046       MemRefReshapeOpLowering,
2047       PrefetchOpLowering,
2048       RankOpLowering,
2049       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2050       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2051       StoreOpLowering,
2052       SubViewOpLowering,
2053       TransposeOpLowering,
2054       ViewOpLowering>(converter);
2055   // clang-format on
2056   auto allocLowering = converter.getOptions().allocLowering;
2057   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2058     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
2059   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2060     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
2061 }
2062 
2063 namespace {
2064 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2065   MemRefToLLVMPass() = default;
2066 
2067   void runOnOperation() override {
2068     Operation *op = getOperation();
2069     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2070     LowerToLLVMOptions options(&getContext(),
2071                                dataLayoutAnalysis.getAtOrAbove(op));
2072     options.allocLowering =
2073         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2074                          : LowerToLLVMOptions::AllocLowering::Malloc);
2075 
2076     options.useGenericFunctions = useGenericFunctions;
2077 
2078     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2079       options.overrideIndexBitwidth(indexBitwidth);
2080 
2081     LLVMTypeConverter typeConverter(&getContext(), options,
2082                                     &dataLayoutAnalysis);
2083     RewritePatternSet patterns(&getContext());
2084     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2085     LLVMConversionTarget target(getContext());
2086     target.addLegalOp<func::FuncOp>();
2087     if (failed(applyPartialConversion(op, target, std::move(patterns))))
2088       signalPassFailure();
2089   }
2090 };
2091 } // namespace
2092 
2093 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2094   return std::make_unique<MemRefToLLVMPass>();
2095 }
2096