1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "llvm/ADT/SmallBitVector.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 
27 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
28   AllocOpLowering(LLVMTypeConverter &converter)
29       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
30                                 converter) {}
31 
32   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
33                                           Location loc, Value sizeBytes,
34                                           Operation *op) const override {
35     // Heap allocations.
36     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
37     MemRefType memRefType = allocOp.getType();
38 
39     Value alignment;
40     if (auto alignmentAttr = allocOp.alignment()) {
41       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
42     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
43       // In the case where no alignment is specified, we may want to override
44       // `malloc's` behavior. `malloc` typically aligns at the size of the
45       // biggest scalar on a target HW. For non-scalars, use the natural
46       // alignment of the LLVM type given by the LLVM DataLayout.
47       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
48     }
49 
50     if (alignment) {
51       // Adjust the allocation size to consider alignment.
52       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
53     }
54 
55     // Allocate the underlying buffer and store a pointer to it in the MemRef
56     // descriptor.
57     Type elementPtrType = this->getElementPtrType(memRefType);
58     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
59         allocOp->getParentOfType<ModuleOp>(), getIndexType());
60     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
61                                   getVoidPtrType());
62     Value allocatedPtr =
63         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
64 
65     Value alignedPtr = allocatedPtr;
66     if (alignment) {
67       // Compute the aligned type pointer.
68       Value allocatedInt =
69           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
70       Value alignmentInt =
71           createAligned(rewriter, loc, allocatedInt, alignment);
72       alignedPtr =
73           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
74     }
75 
76     return std::make_tuple(allocatedPtr, alignedPtr);
77   }
78 };
79 
80 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
81   AlignedAllocOpLowering(LLVMTypeConverter &converter)
82       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
83                                 converter) {}
84 
85   /// Returns the memref's element size in bytes using the data layout active at
86   /// `op`.
87   // TODO: there are other places where this is used. Expose publicly?
88   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
89     const DataLayout *layout = &defaultLayout;
90     if (const DataLayoutAnalysis *analysis =
91             getTypeConverter()->getDataLayoutAnalysis()) {
92       layout = &analysis->getAbove(op);
93     }
94     Type elementType = memRefType.getElementType();
95     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
96       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
97                                                          *layout);
98     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
99       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
100           memRefElementType, *layout);
101     return layout->getTypeSize(elementType);
102   }
103 
104   /// Returns true if the memref size in bytes is known to be a multiple of
105   /// factor assuming the data layout active at `op`.
106   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
107                               Operation *op) const {
108     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
109     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
110       if (ShapedType::isDynamic(type.getDimSize(i)))
111         continue;
112       sizeDivisor = sizeDivisor * type.getDimSize(i);
113     }
114     return sizeDivisor % factor == 0;
115   }
116 
117   /// Returns the alignment to be used for the allocation call itself.
118   /// aligned_alloc requires the allocation size to be a power of two, and the
119   /// allocation size to be a multiple of alignment,
120   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
121     if (Optional<uint64_t> alignment = allocOp.alignment())
122       return *alignment;
123 
124     // Whenever we don't have alignment set, we will use an alignment
125     // consistent with the element type; since the allocation size has to be a
126     // power of two, we will bump to the next power of two if it already isn't.
127     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
128     return std::max(kMinAlignedAllocAlignment,
129                     llvm::PowerOf2Ceil(eltSizeBytes));
130   }
131 
132   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
133                                           Location loc, Value sizeBytes,
134                                           Operation *op) const override {
135     // Heap allocations.
136     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
137     MemRefType memRefType = allocOp.getType();
138     int64_t alignment = getAllocationAlignment(allocOp);
139     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
140 
141     // aligned_alloc requires size to be a multiple of alignment; we will pad
142     // the size to the next multiple if necessary.
143     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
144       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
145 
146     Type elementPtrType = this->getElementPtrType(memRefType);
147     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
148         allocOp->getParentOfType<ModuleOp>(), getIndexType());
149     auto results =
150         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
151                        getVoidPtrType());
152     Value allocatedPtr =
153         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
154 
155     return std::make_tuple(allocatedPtr, allocatedPtr);
156   }
157 
158   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
159   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
160 
161   /// Default layout to use in absence of the corresponding analysis.
162   DataLayout defaultLayout;
163 };
164 
165 // Out of line definition, required till C++17.
166 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
167 
168 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
169   AllocaOpLowering(LLVMTypeConverter &converter)
170       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
171                                 converter) {}
172 
173   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
174   /// is set to null for stack allocations. `accessAlignment` is set if
175   /// alignment is needed post allocation (for eg. in conjunction with malloc).
176   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
177                                           Location loc, Value sizeBytes,
178                                           Operation *op) const override {
179 
180     // With alloca, one gets a pointer to the element type right away.
181     // For stack allocations.
182     auto allocaOp = cast<memref::AllocaOp>(op);
183     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
184 
185     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
186         loc, elementPtrType, sizeBytes,
187         allocaOp.alignment() ? *allocaOp.alignment() : 0);
188 
189     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
190   }
191 };
192 
193 struct AllocaScopeOpLowering
194     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
195   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
196 
197   LogicalResult
198   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
199                   ConversionPatternRewriter &rewriter) const override {
200     OpBuilder::InsertionGuard guard(rewriter);
201     Location loc = allocaScopeOp.getLoc();
202 
203     // Split the current block before the AllocaScopeOp to create the inlining
204     // point.
205     auto *currentBlock = rewriter.getInsertionBlock();
206     auto *remainingOpsBlock =
207         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
208     Block *continueBlock;
209     if (allocaScopeOp.getNumResults() == 0) {
210       continueBlock = remainingOpsBlock;
211     } else {
212       continueBlock = rewriter.createBlock(
213           remainingOpsBlock, allocaScopeOp.getResultTypes(),
214           SmallVector<Location>(allocaScopeOp->getNumResults(),
215                                 allocaScopeOp.getLoc()));
216       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
217     }
218 
219     // Inline body region.
220     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
221     Block *afterBody = &allocaScopeOp.bodyRegion().back();
222     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
223 
224     // Save stack and then branch into the body of the region.
225     rewriter.setInsertionPointToEnd(currentBlock);
226     auto stackSaveOp =
227         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
228     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
229 
230     // Replace the alloca_scope return with a branch that jumps out of the body.
231     // Stack restore before leaving the body region.
232     rewriter.setInsertionPointToEnd(afterBody);
233     auto returnOp =
234         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
235     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
236         returnOp, returnOp.results(), continueBlock);
237 
238     // Insert stack restore before jumping out the body of the region.
239     rewriter.setInsertionPoint(branchOp);
240     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
241 
242     // Replace the op with values return from the body region.
243     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
244 
245     return success();
246   }
247 };
248 
249 struct AssumeAlignmentOpLowering
250     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
251   using ConvertOpToLLVMPattern<
252       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
253 
254   LogicalResult
255   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
256                   ConversionPatternRewriter &rewriter) const override {
257     Value memref = adaptor.memref();
258     unsigned alignment = op.alignment();
259     auto loc = op.getLoc();
260 
261     MemRefDescriptor memRefDescriptor(memref);
262     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
263 
264     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
265     // the asserted memref.alignedPtr isn't used anywhere else, as the real
266     // users like load/store/views always re-extract memref.alignedPtr as they
267     // get lowered.
268     //
269     // This relies on LLVM's CSE optimization (potentially after SROA), since
270     // after CSE all memref.alignedPtr instances get de-duplicated into the same
271     // pointer SSA value.
272     auto intPtrType =
273         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
274     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
275     Value mask =
276         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
277     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
278     rewriter.create<LLVM::AssumeOp>(
279         loc, rewriter.create<LLVM::ICmpOp>(
280                  loc, LLVM::ICmpPredicate::eq,
281                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
282 
283     rewriter.eraseOp(op);
284     return success();
285   }
286 };
287 
288 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
289 // The memref descriptor being an SSA value, there is no need to clean it up
290 // in any way.
291 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
292   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
293 
294   explicit DeallocOpLowering(LLVMTypeConverter &converter)
295       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
296 
297   LogicalResult
298   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
299                   ConversionPatternRewriter &rewriter) const override {
300     // Insert the `free` declaration if it is not already present.
301     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
302     MemRefDescriptor memref(adaptor.memref());
303     Value casted = rewriter.create<LLVM::BitcastOp>(
304         op.getLoc(), getVoidPtrType(),
305         memref.allocatedPtr(rewriter, op.getLoc()));
306     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
307         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
308     return success();
309   }
310 };
311 
312 // A `dim` is converted to a constant for static sizes and to an access to the
313 // size stored in the memref descriptor for dynamic sizes.
314 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
315   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
316 
317   LogicalResult
318   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
319                   ConversionPatternRewriter &rewriter) const override {
320     Type operandType = dimOp.source().getType();
321     if (operandType.isa<UnrankedMemRefType>()) {
322       rewriter.replaceOp(
323           dimOp, {extractSizeOfUnrankedMemRef(
324                      operandType, dimOp, adaptor.getOperands(), rewriter)});
325 
326       return success();
327     }
328     if (operandType.isa<MemRefType>()) {
329       rewriter.replaceOp(
330           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
331                                             adaptor.getOperands(), rewriter)});
332       return success();
333     }
334     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
335   }
336 
337 private:
338   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
339                                     OpAdaptor adaptor,
340                                     ConversionPatternRewriter &rewriter) const {
341     Location loc = dimOp.getLoc();
342 
343     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
344     auto scalarMemRefType =
345         MemRefType::get({}, unrankedMemRefType.getElementType());
346     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
347 
348     // Extract pointer to the underlying ranked descriptor and bitcast it to a
349     // memref<element_type> descriptor pointer to minimize the number of GEP
350     // operations.
351     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
352     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
353     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
354         loc,
355         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
356                                    addressSpace),
357         underlyingRankedDesc);
358 
359     // Get pointer to offset field of memref<element_type> descriptor.
360     Type indexPtrTy = LLVM::LLVMPointerType::get(
361         getTypeConverter()->getIndexType(), addressSpace);
362     Value two = rewriter.create<LLVM::ConstantOp>(
363         loc, typeConverter->convertType(rewriter.getI32Type()),
364         rewriter.getI32IntegerAttr(2));
365     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
366         loc, indexPtrTy, scalarMemRefDescPtr,
367         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
368 
369     // The size value that we have to extract can be obtained using GEPop with
370     // `dimOp.index() + 1` index argument.
371     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
372         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
373     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
374                                                  ValueRange({idxPlusOne}));
375     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
376   }
377 
378   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
379     if (Optional<int64_t> idx = dimOp.getConstantIndex())
380       return idx;
381 
382     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
383       return constantOp.getValue()
384           .cast<IntegerAttr>()
385           .getValue()
386           .getSExtValue();
387 
388     return llvm::None;
389   }
390 
391   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
392                                   OpAdaptor adaptor,
393                                   ConversionPatternRewriter &rewriter) const {
394     Location loc = dimOp.getLoc();
395 
396     // Take advantage if index is constant.
397     MemRefType memRefType = operandType.cast<MemRefType>();
398     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
399       int64_t i = index.getValue();
400       if (memRefType.isDynamicDim(i)) {
401         // extract dynamic size from the memref descriptor.
402         MemRefDescriptor descriptor(adaptor.source());
403         return descriptor.size(rewriter, loc, i);
404       }
405       // Use constant for static size.
406       int64_t dimSize = memRefType.getDimSize(i);
407       return createIndexConstant(rewriter, loc, dimSize);
408     }
409     Value index = adaptor.index();
410     int64_t rank = memRefType.getRank();
411     MemRefDescriptor memrefDescriptor(adaptor.source());
412     return memrefDescriptor.size(rewriter, loc, index, rank);
413   }
414 };
415 
416 /// Common base for load and store operations on MemRefs. Restricts the match
417 /// to supported MemRef types. Provides functionality to emit code accessing a
418 /// specific element of the underlying data buffer.
419 template <typename Derived>
420 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
421   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
422   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
423   using Base = LoadStoreOpLowering<Derived>;
424 
425   LogicalResult match(Derived op) const override {
426     MemRefType type = op.getMemRefType();
427     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
428   }
429 };
430 
431 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
432 /// retried until it succeeds in atomically storing a new value into memory.
433 ///
434 ///      +---------------------------------+
435 ///      |   <code before the AtomicRMWOp> |
436 ///      |   <compute initial %loaded>     |
437 ///      |   cf.br loop(%loaded)              |
438 ///      +---------------------------------+
439 ///             |
440 ///  -------|   |
441 ///  |      v   v
442 ///  |   +--------------------------------+
443 ///  |   | loop(%loaded):                 |
444 ///  |   |   <body contents>              |
445 ///  |   |   %pair = cmpxchg              |
446 ///  |   |   %ok = %pair[0]               |
447 ///  |   |   %new = %pair[1]              |
448 ///  |   |   cf.cond_br %ok, end, loop(%new) |
449 ///  |   +--------------------------------+
450 ///  |          |        |
451 ///  |-----------        |
452 ///                      v
453 ///      +--------------------------------+
454 ///      | end:                           |
455 ///      |   <code after the AtomicRMWOp> |
456 ///      +--------------------------------+
457 ///
458 struct GenericAtomicRMWOpLowering
459     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
460   using Base::Base;
461 
462   LogicalResult
463   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
464                   ConversionPatternRewriter &rewriter) const override {
465     auto loc = atomicOp.getLoc();
466     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
467 
468     // Split the block into initial, loop, and ending parts.
469     auto *initBlock = rewriter.getInsertionBlock();
470     auto *loopBlock = rewriter.createBlock(
471         initBlock->getParent(), std::next(Region::iterator(initBlock)),
472         valueType, loc);
473     auto *endBlock = rewriter.createBlock(
474         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
475 
476     // Operations range to be moved to `endBlock`.
477     auto opsToMoveStart = atomicOp->getIterator();
478     auto opsToMoveEnd = initBlock->back().getIterator();
479 
480     // Compute the loaded value and branch to the loop block.
481     rewriter.setInsertionPointToEnd(initBlock);
482     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
483     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
484                                         adaptor.indices(), rewriter);
485     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
486     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
487 
488     // Prepare the body of the loop block.
489     rewriter.setInsertionPointToStart(loopBlock);
490 
491     // Clone the GenericAtomicRMWOp region and extract the result.
492     auto loopArgument = loopBlock->getArgument(0);
493     BlockAndValueMapping mapping;
494     mapping.map(atomicOp.getCurrentValue(), loopArgument);
495     Block &entryBlock = atomicOp.body().front();
496     for (auto &nestedOp : entryBlock.without_terminator()) {
497       Operation *clone = rewriter.clone(nestedOp, mapping);
498       mapping.map(nestedOp.getResults(), clone->getResults());
499     }
500     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
501 
502     // Prepare the epilog of the loop block.
503     // Append the cmpxchg op to the end of the loop block.
504     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
505     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
506     auto boolType = IntegerType::get(rewriter.getContext(), 1);
507     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
508                                                      {valueType, boolType});
509     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
510         loc, pairType, dataPtr, loopArgument, result, successOrdering,
511         failureOrdering);
512     // Extract the %new_loaded and %ok values from the pair.
513     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
514         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
515     Value ok = rewriter.create<LLVM::ExtractValueOp>(
516         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
517 
518     // Conditionally branch to the end or back to the loop depending on %ok.
519     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
520                                     loopBlock, newLoaded);
521 
522     rewriter.setInsertionPointToEnd(endBlock);
523     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
524                  std::next(opsToMoveEnd), rewriter);
525 
526     // The 'result' of the atomic_rmw op is the newly loaded value.
527     rewriter.replaceOp(atomicOp, {newLoaded});
528 
529     return success();
530   }
531 
532 private:
533   // Clones a segment of ops [start, end) and erases the original.
534   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
535                     Block::iterator start, Block::iterator end,
536                     ConversionPatternRewriter &rewriter) const {
537     BlockAndValueMapping mapping;
538     mapping.map(oldResult, newResult);
539     SmallVector<Operation *, 2> opsToErase;
540     for (auto it = start; it != end; ++it) {
541       rewriter.clone(*it, mapping);
542       opsToErase.push_back(&*it);
543     }
544     for (auto *it : opsToErase)
545       rewriter.eraseOp(it);
546   }
547 };
548 
549 /// Returns the LLVM type of the global variable given the memref type `type`.
550 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
551                                           LLVMTypeConverter &typeConverter) {
552   // LLVM type for a global memref will be a multi-dimension array. For
553   // declarations or uninitialized global memrefs, we can potentially flatten
554   // this to a 1D array. However, for memref.global's with an initial value,
555   // we do not intend to flatten the ElementsAttribute when going from std ->
556   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
557   Type elementType = typeConverter.convertType(type.getElementType());
558   Type arrayTy = elementType;
559   // Shape has the outermost dim at index 0, so need to walk it backwards
560   for (int64_t dim : llvm::reverse(type.getShape()))
561     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
562   return arrayTy;
563 }
564 
565 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
566 struct GlobalMemrefOpLowering
567     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
568   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
569 
570   LogicalResult
571   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
572                   ConversionPatternRewriter &rewriter) const override {
573     MemRefType type = global.type();
574     if (!isConvertibleAndHasIdentityMaps(type))
575       return failure();
576 
577     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
578 
579     LLVM::Linkage linkage =
580         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
581 
582     Attribute initialValue = nullptr;
583     if (!global.isExternal() && !global.isUninitialized()) {
584       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
585       initialValue = elementsAttr;
586 
587       // For scalar memrefs, the global variable created is of the element type,
588       // so unpack the elements attribute to extract the value.
589       if (type.getRank() == 0)
590         initialValue = elementsAttr.getSplatValue<Attribute>();
591     }
592 
593     uint64_t alignment = global.alignment().getValueOr(0);
594 
595     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
596         global, arrayTy, global.constant(), linkage, global.sym_name(),
597         initialValue, alignment, type.getMemorySpaceAsInt());
598     if (!global.isExternal() && global.isUninitialized()) {
599       Block *blk = new Block();
600       newGlobal.getInitializerRegion().push_back(blk);
601       rewriter.setInsertionPointToStart(blk);
602       Value undef[] = {
603           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
604       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
605     }
606     return success();
607   }
608 };
609 
610 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
611 /// the first element stashed into the descriptor. This reuses
612 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
613 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
614   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
615       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
616                                 converter) {}
617 
618   /// Buffer "allocation" for memref.get_global op is getting the address of
619   /// the global variable referenced.
620   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
621                                           Location loc, Value sizeBytes,
622                                           Operation *op) const override {
623     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
624     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
625     unsigned memSpace = type.getMemorySpaceAsInt();
626 
627     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
628     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
629         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
630 
631     // Get the address of the first element in the array by creating a GEP with
632     // the address of the GV as the base, and (rank + 1) number of 0 indices.
633     Type elementType = typeConverter->convertType(type.getElementType());
634     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
635 
636     SmallVector<Value> operands;
637     operands.insert(operands.end(), type.getRank() + 1,
638                     createIndexConstant(rewriter, loc, 0));
639     auto gep =
640         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
641 
642     // We do not expect the memref obtained using `memref.get_global` to be
643     // ever deallocated. Set the allocated pointer to be known bad value to
644     // help debug if that ever happens.
645     auto intPtrType = getIntPtrType(memSpace);
646     Value deadBeefConst =
647         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
648     auto deadBeefPtr =
649         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
650 
651     // Both allocated and aligned pointers are same. We could potentially stash
652     // a nullptr for the allocated pointer since we do not expect any dealloc.
653     return std::make_tuple(deadBeefPtr, gep);
654   }
655 };
656 
657 // Load operation is lowered to obtaining a pointer to the indexed element
658 // and loading it.
659 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
660   using Base::Base;
661 
662   LogicalResult
663   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
664                   ConversionPatternRewriter &rewriter) const override {
665     auto type = loadOp.getMemRefType();
666 
667     Value dataPtr = getStridedElementPtr(
668         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
669     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
670     return success();
671   }
672 };
673 
674 // Store operation is lowered to obtaining a pointer to the indexed element,
675 // and storing the given value to it.
676 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
677   using Base::Base;
678 
679   LogicalResult
680   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
681                   ConversionPatternRewriter &rewriter) const override {
682     auto type = op.getMemRefType();
683 
684     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
685                                          adaptor.indices(), rewriter);
686     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
687     return success();
688   }
689 };
690 
691 // The prefetch operation is lowered in a way similar to the load operation
692 // except that the llvm.prefetch operation is used for replacement.
693 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
694   using Base::Base;
695 
696   LogicalResult
697   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
698                   ConversionPatternRewriter &rewriter) const override {
699     auto type = prefetchOp.getMemRefType();
700     auto loc = prefetchOp.getLoc();
701 
702     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
703                                          adaptor.indices(), rewriter);
704 
705     // Replace with llvm.prefetch.
706     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
707     auto isWrite = rewriter.create<LLVM::ConstantOp>(
708         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
709     auto localityHint = rewriter.create<LLVM::ConstantOp>(
710         loc, llvmI32Type,
711         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
712     auto isData = rewriter.create<LLVM::ConstantOp>(
713         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
714 
715     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
716                                                 localityHint, isData);
717     return success();
718   }
719 };
720 
721 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
722   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
723 
724   LogicalResult
725   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
726                   ConversionPatternRewriter &rewriter) const override {
727     Location loc = op.getLoc();
728     Type operandType = op.memref().getType();
729     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
730       UnrankedMemRefDescriptor desc(adaptor.memref());
731       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
732       return success();
733     }
734     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
735       rewriter.replaceOp(
736           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
737       return success();
738     }
739     return failure();
740   }
741 };
742 
743 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
744   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
745 
746   LogicalResult match(memref::CastOp memRefCastOp) const override {
747     Type srcType = memRefCastOp.getOperand().getType();
748     Type dstType = memRefCastOp.getType();
749 
750     // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
751     // used for type erasure. For now they must preserve underlying element type
752     // and require source and result type to have the same rank. Therefore,
753     // perform a sanity check that the underlying structs are the same. Once op
754     // semantics are relaxed we can revisit.
755     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
756       return success(typeConverter->convertType(srcType) ==
757                      typeConverter->convertType(dstType));
758 
759     // At least one of the operands is unranked type
760     assert(srcType.isa<UnrankedMemRefType>() ||
761            dstType.isa<UnrankedMemRefType>());
762 
763     // Unranked to unranked cast is disallowed
764     return !(srcType.isa<UnrankedMemRefType>() &&
765              dstType.isa<UnrankedMemRefType>())
766                ? success()
767                : failure();
768   }
769 
770   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
771                ConversionPatternRewriter &rewriter) const override {
772     auto srcType = memRefCastOp.getOperand().getType();
773     auto dstType = memRefCastOp.getType();
774     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
775     auto loc = memRefCastOp.getLoc();
776 
777     // For ranked/ranked case, just keep the original descriptor.
778     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
779       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
780 
781     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
782       // Casting ranked to unranked memref type
783       // Set the rank in the destination from the memref type
784       // Allocate space on the stack and copy the src memref descriptor
785       // Set the ptr in the destination to the stack space
786       auto srcMemRefType = srcType.cast<MemRefType>();
787       int64_t rank = srcMemRefType.getRank();
788       // ptr = AllocaOp sizeof(MemRefDescriptor)
789       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
790           loc, adaptor.source(), rewriter);
791       // voidptr = BitCastOp srcType* to void*
792       auto voidPtr =
793           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
794               .getResult();
795       // rank = ConstantOp srcRank
796       auto rankVal = rewriter.create<LLVM::ConstantOp>(
797           loc, typeConverter->convertType(rewriter.getIntegerType(64)),
798           rewriter.getI64IntegerAttr(rank));
799       // undef = UndefOp
800       UnrankedMemRefDescriptor memRefDesc =
801           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
802       // d1 = InsertValueOp undef, rank, 0
803       memRefDesc.setRank(rewriter, loc, rankVal);
804       // d2 = InsertValueOp d1, voidptr, 1
805       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
806       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
807 
808     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
809       // Casting from unranked type to ranked.
810       // The operation is assumed to be doing a correct cast. If the destination
811       // type mismatches the unranked the type, it is undefined behavior.
812       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
813       // ptr = ExtractValueOp src, 1
814       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
815       // castPtr = BitCastOp i8* to structTy*
816       auto castPtr =
817           rewriter
818               .create<LLVM::BitcastOp>(
819                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
820               .getResult();
821       // struct = LoadOp castPtr
822       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
823       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
824     } else {
825       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
826     }
827   }
828 };
829 
830 /// Pattern to lower a `memref.copy` to llvm.
831 ///
832 /// For memrefs with identity layouts, the copy is lowered to the llvm
833 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
834 /// to the generic `MemrefCopyFn`.
835 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
836   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
837 
838   LogicalResult
839   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
840                           ConversionPatternRewriter &rewriter) const {
841     auto loc = op.getLoc();
842     auto srcType = op.source().getType().dyn_cast<MemRefType>();
843 
844     MemRefDescriptor srcDesc(adaptor.source());
845 
846     // Compute number of elements.
847     Value numElements = rewriter.create<LLVM::ConstantOp>(
848         loc, getIndexType(), rewriter.getIndexAttr(1));
849     for (int pos = 0; pos < srcType.getRank(); ++pos) {
850       auto size = srcDesc.size(rewriter, loc, pos);
851       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
852     }
853 
854     // Get element size.
855     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
856     // Compute total.
857     Value totalSize =
858         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
859 
860     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
861     MemRefDescriptor targetDesc(adaptor.target());
862     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
863     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
864         loc, typeConverter->convertType(rewriter.getI1Type()),
865         rewriter.getBoolAttr(false));
866     rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
867                                     isVolatile);
868     rewriter.eraseOp(op);
869 
870     return success();
871   }
872 
873   LogicalResult
874   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
875                              ConversionPatternRewriter &rewriter) const {
876     auto loc = op.getLoc();
877     auto srcType = op.source().getType().cast<BaseMemRefType>();
878     auto targetType = op.target().getType().cast<BaseMemRefType>();
879 
880     // First make sure we have an unranked memref descriptor representation.
881     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
882       auto rank = rewriter.create<LLVM::ConstantOp>(
883           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
884       auto *typeConverter = getTypeConverter();
885       auto ptr =
886           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
887       auto voidPtr =
888           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
889               .getResult();
890       auto unrankedType =
891           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
892       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
893                                             unrankedType,
894                                             ValueRange{rank, voidPtr});
895     };
896 
897     Value unrankedSource = srcType.hasRank()
898                                ? makeUnranked(adaptor.source(), srcType)
899                                : adaptor.source();
900     Value unrankedTarget = targetType.hasRank()
901                                ? makeUnranked(adaptor.target(), targetType)
902                                : adaptor.target();
903 
904     // Now promote the unranked descriptors to the stack.
905     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
906                                                  rewriter.getIndexAttr(1));
907     auto promote = [&](Value desc) {
908       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
909       auto allocated =
910           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
911       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
912       return allocated;
913     };
914 
915     auto sourcePtr = promote(unrankedSource);
916     auto targetPtr = promote(unrankedTarget);
917 
918     unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits(
919         srcType.getElementType());
920     auto elemSize = rewriter.create<LLVM::ConstantOp>(
921         loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8));
922     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
923         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
924     rewriter.create<LLVM::CallOp>(loc, copyFn,
925                                   ValueRange{elemSize, sourcePtr, targetPtr});
926     rewriter.eraseOp(op);
927 
928     return success();
929   }
930 
931   LogicalResult
932   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
933                   ConversionPatternRewriter &rewriter) const override {
934     auto srcType = op.source().getType().cast<BaseMemRefType>();
935     auto targetType = op.target().getType().cast<BaseMemRefType>();
936 
937     if (srcType.hasRank() &&
938         srcType.cast<MemRefType>().getLayout().isIdentity() &&
939         targetType.hasRank() &&
940         targetType.cast<MemRefType>().getLayout().isIdentity())
941       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
942 
943     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
944   }
945 };
946 
947 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
948 /// memref type. In unranked case, the fields are extracted from the underlying
949 /// ranked descriptor.
950 static void extractPointersAndOffset(Location loc,
951                                      ConversionPatternRewriter &rewriter,
952                                      LLVMTypeConverter &typeConverter,
953                                      Value originalOperand,
954                                      Value convertedOperand,
955                                      Value *allocatedPtr, Value *alignedPtr,
956                                      Value *offset = nullptr) {
957   Type operandType = originalOperand.getType();
958   if (operandType.isa<MemRefType>()) {
959     MemRefDescriptor desc(convertedOperand);
960     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
961     *alignedPtr = desc.alignedPtr(rewriter, loc);
962     if (offset != nullptr)
963       *offset = desc.offset(rewriter, loc);
964     return;
965   }
966 
967   unsigned memorySpace =
968       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
969   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
970   Type llvmElementType = typeConverter.convertType(elementType);
971   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
972       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
973 
974   // Extract pointer to the underlying ranked memref descriptor and cast it to
975   // ElemType**.
976   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
977   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
978 
979   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
980       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
981   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
982       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
983   if (offset != nullptr) {
984     *offset = UnrankedMemRefDescriptor::offset(
985         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
986   }
987 }
988 
989 struct MemRefReinterpretCastOpLowering
990     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
991   using ConvertOpToLLVMPattern<
992       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
993 
994   LogicalResult
995   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
996                   ConversionPatternRewriter &rewriter) const override {
997     Type srcType = castOp.source().getType();
998 
999     Value descriptor;
1000     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1001                                                adaptor, &descriptor)))
1002       return failure();
1003     rewriter.replaceOp(castOp, {descriptor});
1004     return success();
1005   }
1006 
1007 private:
1008   LogicalResult convertSourceMemRefToDescriptor(
1009       ConversionPatternRewriter &rewriter, Type srcType,
1010       memref::ReinterpretCastOp castOp,
1011       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1012     MemRefType targetMemRefType =
1013         castOp.getResult().getType().cast<MemRefType>();
1014     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1015                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1016     if (!llvmTargetDescriptorTy)
1017       return failure();
1018 
1019     // Create descriptor.
1020     Location loc = castOp.getLoc();
1021     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1022 
1023     // Set allocated and aligned pointers.
1024     Value allocatedPtr, alignedPtr;
1025     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1026                              castOp.source(), adaptor.source(), &allocatedPtr,
1027                              &alignedPtr);
1028     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1029     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1030 
1031     // Set offset.
1032     if (castOp.isDynamicOffset(0))
1033       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1034     else
1035       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1036 
1037     // Set sizes and strides.
1038     unsigned dynSizeId = 0;
1039     unsigned dynStrideId = 0;
1040     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1041       if (castOp.isDynamicSize(i))
1042         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1043       else
1044         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1045 
1046       if (castOp.isDynamicStride(i))
1047         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1048       else
1049         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1050     }
1051     *descriptor = desc;
1052     return success();
1053   }
1054 };
1055 
1056 struct MemRefReshapeOpLowering
1057     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1058   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1059 
1060   LogicalResult
1061   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1062                   ConversionPatternRewriter &rewriter) const override {
1063     Type srcType = reshapeOp.source().getType();
1064 
1065     Value descriptor;
1066     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1067                                                adaptor, &descriptor)))
1068       return failure();
1069     rewriter.replaceOp(reshapeOp, {descriptor});
1070     return success();
1071   }
1072 
1073 private:
1074   LogicalResult
1075   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1076                                   Type srcType, memref::ReshapeOp reshapeOp,
1077                                   memref::ReshapeOp::Adaptor adaptor,
1078                                   Value *descriptor) const {
1079     // Conversion for statically-known shape args is performed via
1080     // `memref_reinterpret_cast`.
1081     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1082     if (shapeMemRefType.hasStaticShape())
1083       return failure();
1084 
1085     // The shape is a rank-1 tensor with unknown length.
1086     Location loc = reshapeOp.getLoc();
1087     MemRefDescriptor shapeDesc(adaptor.shape());
1088     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1089 
1090     // Extract address space and element type.
1091     auto targetType =
1092         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1093     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1094     Type elementType = targetType.getElementType();
1095 
1096     // Create the unranked memref descriptor that holds the ranked one. The
1097     // inner descriptor is allocated on stack.
1098     auto targetDesc = UnrankedMemRefDescriptor::undef(
1099         rewriter, loc, typeConverter->convertType(targetType));
1100     targetDesc.setRank(rewriter, loc, resultRank);
1101     SmallVector<Value, 4> sizes;
1102     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1103                                            targetDesc, sizes);
1104     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1105         loc, getVoidPtrType(), sizes.front(), llvm::None);
1106     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1107 
1108     // Extract pointers and offset from the source memref.
1109     Value allocatedPtr, alignedPtr, offset;
1110     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1111                              reshapeOp.source(), adaptor.source(),
1112                              &allocatedPtr, &alignedPtr, &offset);
1113 
1114     // Set pointers and offset.
1115     Type llvmElementType = typeConverter->convertType(elementType);
1116     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1117         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1118     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1119                                               elementPtrPtrType, allocatedPtr);
1120     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1121                                             underlyingDescPtr,
1122                                             elementPtrPtrType, alignedPtr);
1123     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1124                                         underlyingDescPtr, elementPtrPtrType,
1125                                         offset);
1126 
1127     // Use the offset pointer as base for further addressing. Copy over the new
1128     // shape and compute strides. For this, we create a loop from rank-1 to 0.
1129     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1130         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1131         elementPtrPtrType);
1132     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1133         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1134     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1135     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1136     Value resultRankMinusOne =
1137         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1138 
    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
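/// For example, reassociation [[0, 1], [2]] yields the map {0: 0, 1: 0, 2: 1}.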
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

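/// Compute the size of a dynamic dimension of the expanded type: the size of
/// the corresponding collapsed dimension divided by the product of the other,
/// necessarily static, expanded dimensions in the same reassociation group.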
static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Compute the product of all the output dimension sizes in this
  // reassociation group except the current dimension.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

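/// Compute the size of a dimension of the collapsed type as the product of the
/// sizes of the expanded dimensions that fold into it.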
static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

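/// Compute the sizes of all dimensions of the collapsed output type.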
static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassocation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassocation);
      }));
}

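/// Compute the sizes of all dimensions of the expanded output type.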
static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassocation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassocation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassocation, outDimToInDimMap);
      }));
}

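/// Compute the output shape as SSA values; a smaller output rank means the
/// reshape collapses dimensions, otherwise it expands them.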
static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassocation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassocation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassocation, inStaticShape,
                                                  inDesc, outStaticShape));
}

/// Conversion pattern that lowers the expand_shape/collapse_shape reshape ops
/// into a new memref descriptor of the proper rank. When the shapes are not
/// fully static, only the identity layout is supported and dynamic strides are
/// recomputed assuming a contiguous row-major buffer.
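///
/// For illustration (a sketch only; the exact IR depends on the type converter
/// and index bitwidth), a collapse such as
///   %1 = memref.collapse_shape %0 [[0, 1]] : memref<4x?xf32> into memref<?xf32>
/// becomes pure descriptor manipulation: the allocated/aligned pointers and
/// the offset are copied from the source descriptor, the result size is
/// computed as 4 * dim(%0, 1), and the result stride is 1.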
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
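///
/// The resulting offset is srcOffset + sum_i(subviewOffset_i * srcStride_i),
/// and each retained stride is subviewStride_i * srcStride_i; dimensions that
/// are dropped by rank reduction are skipped.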
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType = memref::SubViewOp::inferResultType(
                            subViewOp.getSourceType(),
                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
                            extractFromI64ArrayAttr(subViewOp.static_strides()))
                            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in `subViewOp.getMixedSizes()`
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full source dimension size and the stride to be 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor that copy the data ptr and offset and permute
///      the sizes and strides according to the permutation map.
/// The transpose op is replaced by the descriptor.
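///
/// For illustration (a sketch; types elided), a 2-D transpose with permutation
/// (i, j) -> (j, i) simply swaps the two size entries and the two stride
/// entries of the descriptor; no data is moved.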
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
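///
/// For illustration (a sketch; the exact IR depends on the type converter),
///   %view = memref.view %buffer[%offset][%size]
///       : memref<2048xi8> to memref<?x4xf32>
/// produces a descriptor whose aligned pointer is the source pointer advanced
/// by %offset bytes and bitcast to the f32 element type, with offset 0, sizes
/// {%size, 4} and strides {4, 1}.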
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic size.
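  // For example, for shape <4x?x8x?> and idx == 3, one dynamic dimension
  // precedes idx, so `dynamicSizes[1]` is returned.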
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

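/// Lowers memref.atomic_rmw to llvm.atomicrmw when the kind has a direct LLVM
/// equivalent (see `matchSimpleAtomicOp`); otherwise the pattern fails to
/// match.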
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
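  // memref.alloc is lowered to aligned_alloc or malloc depending on the pass
  // options; for any other AllocLowering setting no alloc/dealloc patterns are
  // added here.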
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}