1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 
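/// `memref.alloc` is lowered to a call to `malloc` followed by a bitcast of
/// the returned pointer to the element pointer type. Illustrative sketch of
/// the unaligned case (simplified; f32 element type and i64 index chosen for
/// illustration):
///
///   %raw = llvm.call @malloc(%sizeBytes) : (i64) -> !llvm.ptr<i8>
///   %ptr = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
///
/// When an alignment is requested (or implied by a non-scalar element type),
/// the size is padded by the alignment and the aligned pointer is derived
/// from %ptr via ptrtoint/inttoptr arithmetic.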
26 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
27   AllocOpLowering(LLVMTypeConverter &converter)
28       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
29                                 converter) {}
30 
31   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
32                                           Location loc, Value sizeBytes,
33                                           Operation *op) const override {
34     // Heap allocations.
35     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
36     MemRefType memRefType = allocOp.getType();
37 
38     Value alignment;
39     if (auto alignmentAttr = allocOp.alignment()) {
40       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
41     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
42       // In the case where no alignment is specified, we may want to override
43       // `malloc's` behavior. `malloc` typically aligns at the size of the
44       // biggest scalar on a target HW. For non-scalars, use the natural
45       // alignment of the LLVM type given by the LLVM DataLayout.
46       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
47     }
48 
49     if (alignment) {
50       // Adjust the allocation size to consider alignment.
51       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
52     }
53 
54     // Allocate the underlying buffer and store a pointer to it in the MemRef
55     // descriptor.
56     Type elementPtrType = this->getElementPtrType(memRefType);
57     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
58         allocOp->getParentOfType<ModuleOp>(), getIndexType());
59     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
60                                   getVoidPtrType());
61     Value allocatedPtr =
62         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
63 
64     Value alignedPtr = allocatedPtr;
65     if (alignment) {
66       // Compute the aligned type pointer.
67       Value allocatedInt =
68           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
69       Value alignmentInt =
70           createAligned(rewriter, loc, allocatedInt, alignment);
71       alignedPtr =
72           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
73     }
74 
75     return std::make_tuple(allocatedPtr, alignedPtr);
76   }
77 };
78 
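/// Alternative lowering of `memref.alloc` that calls the C11 `aligned_alloc`
/// function instead of plain `malloc`. Illustrative sketch (simplified):
///
///   %raw = llvm.call @aligned_alloc(%alignment, %sizeBytes)
///       : (i64, i64) -> !llvm.ptr<i8>
///
/// The alignment passed to the call is either the one requested on the op or
/// one derived from the element size (see getAllocationAlignment below).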
79 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
80   AlignedAllocOpLowering(LLVMTypeConverter &converter)
81       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
82                                 converter) {}
83 
84   /// Returns the memref's element size in bytes using the data layout active at
85   /// `op`.
86   // TODO: there are other places where this is used. Expose publicly?
87   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
88     const DataLayout *layout = &defaultLayout;
89     if (const DataLayoutAnalysis *analysis =
90             getTypeConverter()->getDataLayoutAnalysis()) {
91       layout = &analysis->getAbove(op);
92     }
93     Type elementType = memRefType.getElementType();
94     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
95       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
96                                                          *layout);
97     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
98       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
99           memRefElementType, *layout);
100     return layout->getTypeSize(elementType);
101   }
102 
103   /// Returns true if the memref size in bytes is known to be a multiple of
104   /// factor assuming the data layout active at `op`.
105   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
106                               Operation *op) const {
107     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
108     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
109       if (ShapedType::isDynamic(type.getDimSize(i)))
110         continue;
111       sizeDivisor = sizeDivisor * type.getDimSize(i);
112     }
113     return sizeDivisor % factor == 0;
114   }
115 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
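  /// For example (illustrative), an element size of 6 bytes is rounded up to
  /// the next power of two (8) and then clamped to kMinAlignedAllocAlignment,
  /// yielding an alignment of 16.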
119   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
120     if (Optional<uint64_t> alignment = allocOp.alignment())
121       return *alignment;
122 
    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if needed.
126     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
127     return std::max(kMinAlignedAllocAlignment,
128                     llvm::PowerOf2Ceil(eltSizeBytes));
129   }
130 
131   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
132                                           Location loc, Value sizeBytes,
133                                           Operation *op) const override {
134     // Heap allocations.
135     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
136     MemRefType memRefType = allocOp.getType();
137     int64_t alignment = getAllocationAlignment(allocOp);
138     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
139 
140     // aligned_alloc requires size to be a multiple of alignment; we will pad
141     // the size to the next multiple if necessary.
142     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
143       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
144 
145     Type elementPtrType = this->getElementPtrType(memRefType);
146     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
147         allocOp->getParentOfType<ModuleOp>(), getIndexType());
148     auto results =
149         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
150                        getVoidPtrType());
151     Value allocatedPtr =
152         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
153 
154     return std::make_tuple(allocatedPtr, allocatedPtr);
155   }
156 
157   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
158   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
159 
160   /// Default layout to use in absence of the corresponding analysis.
161   DataLayout defaultLayout;
162 };
163 
// Out-of-line definition, required until C++17.
165 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
166 
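/// `memref.alloca` is lowered to an `llvm.alloca` of the element type with
/// the requested alignment, roughly (illustrative, f32 element type):
///
///   %buf = llvm.alloca %size x f32 {alignment = 8} : (i64) -> !llvm.ptr<f32>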
167 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
168   AllocaOpLowering(LLVMTypeConverter &converter)
169       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
170                                 converter) {}
171 
  /// Allocates the underlying buffer with an llvm.alloca of the element type.
  /// For stack allocations no post-allocation alignment fixup is needed: the
  /// requested alignment is passed directly to the alloca, so the allocated
  /// and aligned pointers coincide.
175   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
176                                           Location loc, Value sizeBytes,
177                                           Operation *op) const override {
178 
    // With alloca, one gets a pointer to the element type right away; no
    // bitcast from a raw byte pointer is needed as in the malloc-based case.
181     auto allocaOp = cast<memref::AllocaOp>(op);
182     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
183 
184     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
185         loc, elementPtrType, sizeBytes,
186         allocaOp.alignment() ? *allocaOp.alignment() : 0);
187 
188     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
189   }
190 };
191 
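/// `memref.alloca_scope` is lowered by inlining its body between an
/// `llvm.intr.stacksave` / `llvm.intr.stackrestore` pair, so that stack
/// allocations made inside the scope are released when control leaves it.
/// Rough structure of the emitted CFG (illustrative):
///
///   <current block>:
///     %sp = llvm.intr.stacksave : !llvm.ptr<i8>
///     llvm.br ^body
///   ^body:
///     ... inlined region ...
///     llvm.intr.stackrestore %sp
///     llvm.br ^continue(...)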
192 struct AllocaScopeOpLowering
193     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
194   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
195 
196   LogicalResult
197   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override {
199     OpBuilder::InsertionGuard guard(rewriter);
200     Location loc = allocaScopeOp.getLoc();
201 
202     // Split the current block before the AllocaScopeOp to create the inlining
203     // point.
204     auto *currentBlock = rewriter.getInsertionBlock();
205     auto *remainingOpsBlock =
206         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
207     Block *continueBlock;
208     if (allocaScopeOp.getNumResults() == 0) {
209       continueBlock = remainingOpsBlock;
210     } else {
211       continueBlock = rewriter.createBlock(
212           remainingOpsBlock, allocaScopeOp.getResultTypes(),
213           SmallVector<Location>(allocaScopeOp->getNumResults(),
214                                 allocaScopeOp.getLoc()));
215       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
216     }
217 
218     // Inline body region.
219     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
220     Block *afterBody = &allocaScopeOp.bodyRegion().back();
221     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
222 
223     // Save stack and then branch into the body of the region.
224     rewriter.setInsertionPointToEnd(currentBlock);
225     auto stackSaveOp =
226         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
227     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
228 
229     // Replace the alloca_scope return with a branch that jumps out of the body.
230     // Stack restore before leaving the body region.
231     rewriter.setInsertionPointToEnd(afterBody);
232     auto returnOp =
233         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
234     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
235         returnOp, returnOp.results(), continueBlock);
236 
    // Insert stack restore before jumping out of the body of the region.
238     rewriter.setInsertionPoint(branchOp);
239     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
240 
    // Replace the op with values returned from the body region.
242     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
243 
244     return success();
245   }
246 };
247 
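/// `memref.assume_alignment` is lowered to an `llvm.intr.assume` over a check
/// that the aligned pointer has the requested alignment. Roughly
/// (illustrative, generic op form; f32 element type chosen for illustration):
///
///   %p    = llvm.ptrtoint %alignedPtr : !llvm.ptr<f32> to i64
///   %rem  = llvm.and %p, %mask : i64        // mask = alignment - 1
///   %cond = llvm.icmp "eq" %rem, %zero : i64
///   "llvm.intr.assume"(%cond) : (i1) -> ()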
248 struct AssumeAlignmentOpLowering
249     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
250   using ConvertOpToLLVMPattern<
251       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
252 
253   LogicalResult
254   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
255                   ConversionPatternRewriter &rewriter) const override {
256     Value memref = adaptor.memref();
257     unsigned alignment = op.alignment();
258     auto loc = op.getLoc();
259 
260     MemRefDescriptor memRefDescriptor(memref);
261     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
262 
263     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
264     // the asserted memref.alignedPtr isn't used anywhere else, as the real
265     // users like load/store/views always re-extract memref.alignedPtr as they
266     // get lowered.
267     //
268     // This relies on LLVM's CSE optimization (potentially after SROA), since
269     // after CSE all memref.alignedPtr instances get de-duplicated into the same
270     // pointer SSA value.
271     auto intPtrType =
272         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
273     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
274     Value mask =
275         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
276     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
277     rewriter.create<LLVM::AssumeOp>(
278         loc, rewriter.create<LLVM::ICmpOp>(
279                  loc, LLVM::ICmpPredicate::eq,
280                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
281 
282     rewriter.eraseOp(op);
283     return success();
284   }
285 };
286 
287 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
288 // The memref descriptor being an SSA value, there is no need to clean it up
289 // in any way.
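// Illustrative sketch of the emitted IR (simplified, f32 element type):
//
//   %ptr  = llvm.extractvalue %desc[0] : !llvm.struct<(...)>  // allocated ptr
//   %void = llvm.bitcast %ptr : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()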
290 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
291   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
292 
293   explicit DeallocOpLowering(LLVMTypeConverter &converter)
294       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
295 
296   LogicalResult
297   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
298                   ConversionPatternRewriter &rewriter) const override {
299     // Insert the `free` declaration if it is not already present.
300     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
301     MemRefDescriptor memref(adaptor.memref());
302     Value casted = rewriter.create<LLVM::BitcastOp>(
303         op.getLoc(), getVoidPtrType(),
304         memref.allocatedPtr(rewriter, op.getLoc()));
305     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
306         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
307     return success();
308   }
309 };
310 
311 // A `dim` is converted to a constant for static sizes and to an access to the
312 // size stored in the memref descriptor for dynamic sizes.
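// For example (illustrative), with %m : memref<4x?xf32> and its converted
// descriptor %desc:
//   memref.dim %m, %c0  ->  an index constant 4
//   memref.dim %m, %c1  ->  llvm.extractvalue %desc[3, 1]  (the size field)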
313 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
314   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
315 
316   LogicalResult
317   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
318                   ConversionPatternRewriter &rewriter) const override {
319     Type operandType = dimOp.source().getType();
320     if (operandType.isa<UnrankedMemRefType>()) {
321       rewriter.replaceOp(
322           dimOp, {extractSizeOfUnrankedMemRef(
323                      operandType, dimOp, adaptor.getOperands(), rewriter)});
324 
325       return success();
326     }
327     if (operandType.isa<MemRefType>()) {
328       rewriter.replaceOp(
329           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
330                                             adaptor.getOperands(), rewriter)});
331       return success();
332     }
333     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
334   }
335 
336 private:
337   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
338                                     OpAdaptor adaptor,
339                                     ConversionPatternRewriter &rewriter) const {
340     Location loc = dimOp.getLoc();
341 
342     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
343     auto scalarMemRefType =
344         MemRefType::get({}, unrankedMemRefType.getElementType());
345     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
346 
347     // Extract pointer to the underlying ranked descriptor and bitcast it to a
348     // memref<element_type> descriptor pointer to minimize the number of GEP
349     // operations.
350     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
351     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
352     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
353         loc,
354         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
355                                    addressSpace),
356         underlyingRankedDesc);
357 
358     // Get pointer to offset field of memref<element_type> descriptor.
359     Type indexPtrTy = LLVM::LLVMPointerType::get(
360         getTypeConverter()->getIndexType(), addressSpace);
361     Value two = rewriter.create<LLVM::ConstantOp>(
362         loc, typeConverter->convertType(rewriter.getI32Type()),
363         rewriter.getI32IntegerAttr(2));
364     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
365         loc, indexPtrTy, scalarMemRefDescPtr,
366         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
367 
    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
370     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
371         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
372     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
373                                                  ValueRange({idxPlusOne}));
374     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
375   }
376 
377   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
378     if (Optional<int64_t> idx = dimOp.getConstantIndex())
379       return idx;
380 
381     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
382       return constantOp.getValue()
383           .cast<IntegerAttr>()
384           .getValue()
385           .getSExtValue();
386 
387     return llvm::None;
388   }
389 
390   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
391                                   OpAdaptor adaptor,
392                                   ConversionPatternRewriter &rewriter) const {
393     Location loc = dimOp.getLoc();
394 
    // Take advantage of the constant index when possible.
396     MemRefType memRefType = operandType.cast<MemRefType>();
397     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
398       int64_t i = index.getValue();
399       if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
401         MemRefDescriptor descriptor(adaptor.source());
402         return descriptor.size(rewriter, loc, i);
403       }
404       // Use constant for static size.
405       int64_t dimSize = memRefType.getDimSize(i);
406       return createIndexConstant(rewriter, loc, dimSize);
407     }
408     Value index = adaptor.index();
409     int64_t rank = memRefType.getRank();
410     MemRefDescriptor memrefDescriptor(adaptor.source());
411     return memrefDescriptor.size(rewriter, loc, index, rank);
412   }
413 };
414 
415 /// Common base for load and store operations on MemRefs. Restricts the match
416 /// to supported MemRef types. Provides functionality to emit code accessing a
417 /// specific element of the underlying data buffer.
418 template <typename Derived>
419 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
420   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
421   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
422   using Base = LoadStoreOpLowering<Derived>;
423 
424   LogicalResult match(Derived op) const override {
425     MemRefType type = op.getMemRefType();
426     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
427   }
428 };
429 
/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
432 ///
433 ///      +---------------------------------+
434 ///      |   <code before the AtomicRMWOp> |
435 ///      |   <compute initial %loaded>     |
436 ///      |   br loop(%loaded)              |
437 ///      +---------------------------------+
438 ///             |
439 ///  -------|   |
440 ///  |      v   v
441 ///  |   +--------------------------------+
442 ///  |   | loop(%loaded):                 |
443 ///  |   |   <body contents>              |
444 ///  |   |   %pair = cmpxchg              |
445 ///  |   |   %ok = %pair[0]               |
446 ///  |   |   %new = %pair[1]              |
447 ///  |   |   cond_br %ok, end, loop(%new) |
448 ///  |   +--------------------------------+
449 ///  |          |        |
450 ///  |-----------        |
451 ///                      v
452 ///      +--------------------------------+
453 ///      | end:                           |
454 ///      |   <code after the AtomicRMWOp> |
455 ///      +--------------------------------+
456 ///
457 struct GenericAtomicRMWOpLowering
458     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
459   using Base::Base;
460 
461   LogicalResult
462   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
463                   ConversionPatternRewriter &rewriter) const override {
464     auto loc = atomicOp.getLoc();
465     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
466 
467     // Split the block into initial, loop, and ending parts.
468     auto *initBlock = rewriter.getInsertionBlock();
469     auto *loopBlock = rewriter.createBlock(
470         initBlock->getParent(), std::next(Region::iterator(initBlock)),
471         valueType, loc);
472     auto *endBlock = rewriter.createBlock(
473         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
474 
    // Range of operations to be moved to `endBlock`.
476     auto opsToMoveStart = atomicOp->getIterator();
477     auto opsToMoveEnd = initBlock->back().getIterator();
478 
479     // Compute the loaded value and branch to the loop block.
480     rewriter.setInsertionPointToEnd(initBlock);
481     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
482     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
483                                         adaptor.indices(), rewriter);
484     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
485     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
486 
487     // Prepare the body of the loop block.
488     rewriter.setInsertionPointToStart(loopBlock);
489 
490     // Clone the GenericAtomicRMWOp region and extract the result.
491     auto loopArgument = loopBlock->getArgument(0);
492     BlockAndValueMapping mapping;
493     mapping.map(atomicOp.getCurrentValue(), loopArgument);
494     Block &entryBlock = atomicOp.body().front();
495     for (auto &nestedOp : entryBlock.without_terminator()) {
496       Operation *clone = rewriter.clone(nestedOp, mapping);
497       mapping.map(nestedOp.getResults(), clone->getResults());
498     }
499     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
500 
501     // Prepare the epilog of the loop block.
502     // Append the cmpxchg op to the end of the loop block.
503     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
504     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
505     auto boolType = IntegerType::get(rewriter.getContext(), 1);
506     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
507                                                      {valueType, boolType});
508     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
509         loc, pairType, dataPtr, loopArgument, result, successOrdering,
510         failureOrdering);
511     // Extract the %new_loaded and %ok values from the pair.
512     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
513         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
514     Value ok = rewriter.create<LLVM::ExtractValueOp>(
515         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
516 
517     // Conditionally branch to the end or back to the loop depending on %ok.
518     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
519                                     loopBlock, newLoaded);
520 
521     rewriter.setInsertionPointToEnd(endBlock);
522     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
523                  std::next(opsToMoveEnd), rewriter);
524 
525     // The 'result' of the atomic_rmw op is the newly loaded value.
526     rewriter.replaceOp(atomicOp, {newLoaded});
527 
528     return success();
529   }
530 
531 private:
532   // Clones a segment of ops [start, end) and erases the original.
533   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
534                     Block::iterator start, Block::iterator end,
535                     ConversionPatternRewriter &rewriter) const {
536     BlockAndValueMapping mapping;
537     mapping.map(oldResult, newResult);
538     SmallVector<Operation *, 2> opsToErase;
539     for (auto it = start; it != end; ++it) {
540       rewriter.clone(*it, mapping);
541       opsToErase.push_back(&*it);
542     }
543     for (auto *it : opsToErase)
544       rewriter.eraseOp(it);
545   }
546 };
547 
548 /// Returns the LLVM type of the global variable given the memref type `type`.
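/// For example, `memref<4x8xf32>` maps to `!llvm.array<4 x array<8 x f32>>`
/// (assuming the element type converts to `f32`).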
549 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
550                                           LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global ops with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from the
  // memref dialect to the LLVM dialect, so the LLVM type needs to be a
  // multi-dimensional array.
556   Type elementType = typeConverter.convertType(type.getElementType());
557   Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so walk it backwards.
559   for (int64_t dim : llvm::reverse(type.getShape()))
560     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
561   return arrayTy;
562 }
563 
/// GlobalMemrefOp is lowered to an LLVM global variable.
565 struct GlobalMemrefOpLowering
566     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
567   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
568 
569   LogicalResult
570   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
571                   ConversionPatternRewriter &rewriter) const override {
572     MemRefType type = global.type();
573     if (!isConvertibleAndHasIdentityMaps(type))
574       return failure();
575 
576     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
577 
578     LLVM::Linkage linkage =
579         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
580 
581     Attribute initialValue = nullptr;
582     if (!global.isExternal() && !global.isUninitialized()) {
583       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
584       initialValue = elementsAttr;
585 
586       // For scalar memrefs, the global variable created is of the element type,
587       // so unpack the elements attribute to extract the value.
588       if (type.getRank() == 0)
589         initialValue = elementsAttr.getSplatValue<Attribute>();
590     }
591 
592     uint64_t alignment = global.alignment().getValueOr(0);
593 
594     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
595         global, arrayTy, global.constant(), linkage, global.sym_name(),
596         initialValue, alignment, type.getMemorySpaceAsInt());
597     if (!global.isExternal() && global.isUninitialized()) {
598       Block *blk = new Block();
599       newGlobal.getInitializerRegion().push_back(blk);
600       rewriter.setInsertionPointToStart(blk);
601       Value undef[] = {
602           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
603       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
604     }
605     return success();
606   }
607 };
608 
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// AllocLikeOpLLVMLowering to reuse the MemRef descriptor construction.
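/// Illustrative sketch of the buffer "allocation" (simplified, 1-D f32 case):
///
///   %addr = llvm.mlir.addressof @the_global : !llvm.ptr<array<4 x f32>>
///   %ptr  = llvm.getelementptr %addr[%c0, %c0]
///       : (!llvm.ptr<array<4 x f32>>, i64, i64) -> !llvm.ptr<f32>
///
/// `@the_global` is a placeholder for the symbol referenced by the op.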
612 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
613   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
614       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
615                                 converter) {}
616 
  /// Buffer "allocation" for a memref.get_global op amounts to taking the
  /// address of the referenced global variable.
619   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
620                                           Location loc, Value sizeBytes,
621                                           Operation *op) const override {
622     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
623     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
624     unsigned memSpace = type.getMemorySpaceAsInt();
625 
626     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
627     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
628         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
629 
630     // Get the address of the first element in the array by creating a GEP with
631     // the address of the GV as the base, and (rank + 1) number of 0 indices.
632     Type elementType = typeConverter->convertType(type.getElementType());
633     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
634 
635     SmallVector<Value> operands;
636     operands.insert(operands.end(), type.getRank() + 1,
637                     createIndexConstant(rewriter, loc, 0));
638     auto gep =
639         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
640 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
644     auto intPtrType = getIntPtrType(memSpace);
645     Value deadBeefConst =
646         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
647     auto deadBeefPtr =
648         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
649 
    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
652     return std::make_tuple(deadBeefPtr, gep);
653   }
654 };
655 
656 // Load operation is lowered to obtaining a pointer to the indexed element
657 // and loading it.
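// For an identity-layout memref, the address computation roughly becomes
// (illustrative, 1-D case with zero offset and f32 element type):
//
//   %base = llvm.extractvalue %desc[1] : !llvm.struct<(...)>  // aligned ptr
//   %addr = llvm.getelementptr %base[%i]
//       : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
//   %val  = llvm.load %addr : !llvm.ptr<f32>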
658 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
659   using Base::Base;
660 
661   LogicalResult
662   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
663                   ConversionPatternRewriter &rewriter) const override {
664     auto type = loadOp.getMemRefType();
665 
666     Value dataPtr = getStridedElementPtr(
667         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
668     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
669     return success();
670   }
671 };
672 
673 // Store operation is lowered to obtaining a pointer to the indexed element,
674 // and storing the given value to it.
675 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
676   using Base::Base;
677 
678   LogicalResult
679   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
680                   ConversionPatternRewriter &rewriter) const override {
681     auto type = op.getMemRefType();
682 
683     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
684                                          adaptor.indices(), rewriter);
685     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
686     return success();
687   }
688 };
689 
690 // The prefetch operation is lowered in a way similar to the load operation
691 // except that the llvm.prefetch operation is used for replacement.
692 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
693   using Base::Base;
694 
695   LogicalResult
696   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
697                   ConversionPatternRewriter &rewriter) const override {
698     auto type = prefetchOp.getMemRefType();
699     auto loc = prefetchOp.getLoc();
700 
701     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
702                                          adaptor.indices(), rewriter);
703 
704     // Replace with llvm.prefetch.
705     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
706     auto isWrite = rewriter.create<LLVM::ConstantOp>(
707         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
708     auto localityHint = rewriter.create<LLVM::ConstantOp>(
709         loc, llvmI32Type,
710         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
711     auto isData = rewriter.create<LLVM::ConstantOp>(
712         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
713 
714     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
715                                                 localityHint, isData);
716     return success();
717   }
718 };
719 
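/// `memref.rank` is lowered to a constant for ranked memrefs and to an
/// extraction of the rank field from the unranked descriptor otherwise.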
720 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
721   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
722 
723   LogicalResult
724   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
725                   ConversionPatternRewriter &rewriter) const override {
726     Location loc = op.getLoc();
727     Type operandType = op.memref().getType();
728     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
729       UnrankedMemRefDescriptor desc(adaptor.memref());
730       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
731       return success();
732     }
733     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
734       rewriter.replaceOp(
735           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
736       return success();
737     }
738     return failure();
739   }
740 };
741 
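/// `memref.cast` lowering: ranked-to-ranked casts keep the original
/// descriptor; ranked-to-unranked casts build an unranked descriptor holding
/// the rank and a pointer to a stack copy of the ranked descriptor;
/// unranked-to-ranked casts load the ranked descriptor back through that
/// pointer.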
742 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
743   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
744 
745   LogicalResult match(memref::CastOp memRefCastOp) const override {
746     Type srcType = memRefCastOp.getOperand().getType();
747     Type dstType = memRefCastOp.getType();
748 
    // A memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires the source and result types to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once the op semantics are relaxed we can revisit.
754     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
755       return success(typeConverter->convertType(srcType) ==
756                      typeConverter->convertType(dstType));
757 
    // At least one of the operands has an unranked memref type.
759     assert(srcType.isa<UnrankedMemRefType>() ||
760            dstType.isa<UnrankedMemRefType>());
761 
    // Unranked-to-unranked casts are disallowed.
763     return !(srcType.isa<UnrankedMemRefType>() &&
764              dstType.isa<UnrankedMemRefType>())
765                ? success()
766                : failure();
767   }
768 
769   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
770                ConversionPatternRewriter &rewriter) const override {
771     auto srcType = memRefCastOp.getOperand().getType();
772     auto dstType = memRefCastOp.getType();
773     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
774     auto loc = memRefCastOp.getLoc();
775 
776     // For ranked/ranked case, just keep the original descriptor.
777     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
778       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
779 
780     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
781       // Casting ranked to unranked memref type
782       // Set the rank in the destination from the memref type
783       // Allocate space on the stack and copy the src memref descriptor
784       // Set the ptr in the destination to the stack space
785       auto srcMemRefType = srcType.cast<MemRefType>();
786       int64_t rank = srcMemRefType.getRank();
787       // ptr = AllocaOp sizeof(MemRefDescriptor)
788       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
789           loc, adaptor.source(), rewriter);
790       // voidptr = BitCastOp srcType* to void*
791       auto voidPtr =
792           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
793               .getResult();
794       // rank = ConstantOp srcRank
795       auto rankVal = rewriter.create<LLVM::ConstantOp>(
796           loc, typeConverter->convertType(rewriter.getIntegerType(64)),
797           rewriter.getI64IntegerAttr(rank));
798       // undef = UndefOp
799       UnrankedMemRefDescriptor memRefDesc =
800           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
801       // d1 = InsertValueOp undef, rank, 0
802       memRefDesc.setRank(rewriter, loc, rankVal);
803       // d2 = InsertValueOp d1, voidptr, 1
804       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
805       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
806 
807     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type does not match the runtime type of the unranked
      // operand, the behavior is undefined.
811       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
812       // ptr = ExtractValueOp src, 1
813       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
814       // castPtr = BitCastOp i8* to structTy*
815       auto castPtr =
816           rewriter
817               .create<LLVM::BitcastOp>(
818                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
819               .getResult();
820       // struct = LoadOp castPtr
821       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
822       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
823     } else {
824       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
825     }
826   }
827 };
828 
829 /// Pattern to lower a `memref.copy` to llvm.
830 ///
831 /// For memrefs with identity layouts, the copy is lowered to the llvm
832 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
833 /// to the generic `MemrefCopyFn`.
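///
/// Illustrative sketch of the identity-layout path (simplified, generic op
/// form; the byte count is the product of the sizes times the element size):
///
///   "llvm.intr.memcpy"(%dstAligned, %srcAligned, %totalBytes, %false)
///       : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()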
834 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
835   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
836 
837   LogicalResult
838   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
839                           ConversionPatternRewriter &rewriter) const {
840     auto loc = op.getLoc();
841     auto srcType = op.source().getType().dyn_cast<MemRefType>();
842 
843     MemRefDescriptor srcDesc(adaptor.source());
844 
845     // Compute number of elements.
846     Value numElements = rewriter.create<LLVM::ConstantOp>(
847         loc, getIndexType(), rewriter.getIndexAttr(1));
848     for (int pos = 0; pos < srcType.getRank(); ++pos) {
849       auto size = srcDesc.size(rewriter, loc, pos);
850       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
851     }
852 
853     // Get element size.
854     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
855     // Compute total.
856     Value totalSize =
857         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
858 
859     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
860     MemRefDescriptor targetDesc(adaptor.target());
861     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
862     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
863         loc, typeConverter->convertType(rewriter.getI1Type()),
864         rewriter.getBoolAttr(false));
865     rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
866                                     isVolatile);
867     rewriter.eraseOp(op);
868 
869     return success();
870   }
871 
872   LogicalResult
873   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
874                              ConversionPatternRewriter &rewriter) const {
875     auto loc = op.getLoc();
876     auto srcType = op.source().getType().cast<BaseMemRefType>();
877     auto targetType = op.target().getType().cast<BaseMemRefType>();
878 
879     // First make sure we have an unranked memref descriptor representation.
880     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
881       auto rank = rewriter.create<LLVM::ConstantOp>(
882           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
883       auto *typeConverter = getTypeConverter();
884       auto ptr =
885           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
886       auto voidPtr =
887           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
888               .getResult();
889       auto unrankedType =
890           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
891       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
892                                             unrankedType,
893                                             ValueRange{rank, voidPtr});
894     };
895 
896     Value unrankedSource = srcType.hasRank()
897                                ? makeUnranked(adaptor.source(), srcType)
898                                : adaptor.source();
899     Value unrankedTarget = targetType.hasRank()
900                                ? makeUnranked(adaptor.target(), targetType)
901                                : adaptor.target();
902 
903     // Now promote the unranked descriptors to the stack.
904     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
905                                                  rewriter.getIndexAttr(1));
906     auto promote = [&](Value desc) {
907       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
908       auto allocated =
909           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
910       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
911       return allocated;
912     };
913 
914     auto sourcePtr = promote(unrankedSource);
915     auto targetPtr = promote(unrankedTarget);
916 
917     unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits(
918         srcType.getElementType());
919     auto elemSize = rewriter.create<LLVM::ConstantOp>(
920         loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8));
921     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
922         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
923     rewriter.create<LLVM::CallOp>(loc, copyFn,
924                                   ValueRange{elemSize, sourcePtr, targetPtr});
925     rewriter.eraseOp(op);
926 
927     return success();
928   }
929 
930   LogicalResult
931   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
932                   ConversionPatternRewriter &rewriter) const override {
933     auto srcType = op.source().getType().cast<BaseMemRefType>();
934     auto targetType = op.target().getType().cast<BaseMemRefType>();
935 
936     if (srcType.hasRank() &&
937         srcType.cast<MemRefType>().getLayout().isIdentity() &&
938         targetType.hasRank() &&
939         targetType.cast<MemRefType>().getLayout().isIdentity())
940       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
941 
942     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
943   }
944 };
945 
/// Extracts the allocated and aligned pointers and the offset from a ranked
/// or unranked memref value. In the unranked case, the fields are extracted
/// from the underlying ranked descriptor.
949 static void extractPointersAndOffset(Location loc,
950                                      ConversionPatternRewriter &rewriter,
951                                      LLVMTypeConverter &typeConverter,
952                                      Value originalOperand,
953                                      Value convertedOperand,
954                                      Value *allocatedPtr, Value *alignedPtr,
955                                      Value *offset = nullptr) {
956   Type operandType = originalOperand.getType();
957   if (operandType.isa<MemRefType>()) {
958     MemRefDescriptor desc(convertedOperand);
959     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
960     *alignedPtr = desc.alignedPtr(rewriter, loc);
961     if (offset != nullptr)
962       *offset = desc.offset(rewriter, loc);
963     return;
964   }
965 
966   unsigned memorySpace =
967       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
968   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
969   Type llvmElementType = typeConverter.convertType(elementType);
970   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
971       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
972 
973   // Extract pointer to the underlying ranked memref descriptor and cast it to
974   // ElemType**.
975   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
976   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
977 
978   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
979       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
980   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
981       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
982   if (offset != nullptr) {
983     *offset = UnrankedMemRefDescriptor::offset(
984         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
985   }
986 }
987 
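/// `memref.reinterpret_cast` is lowered by building a fresh MemRef descriptor:
/// the allocated/aligned pointers are taken from the source (ranked or
/// unranked), and the offset, sizes, and strides are filled in from the op's
/// static values or dynamic operands.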
988 struct MemRefReinterpretCastOpLowering
989     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
990   using ConvertOpToLLVMPattern<
991       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
992 
993   LogicalResult
994   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
995                   ConversionPatternRewriter &rewriter) const override {
996     Type srcType = castOp.source().getType();
997 
998     Value descriptor;
999     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1000                                                adaptor, &descriptor)))
1001       return failure();
1002     rewriter.replaceOp(castOp, {descriptor});
1003     return success();
1004   }
1005 
1006 private:
1007   LogicalResult convertSourceMemRefToDescriptor(
1008       ConversionPatternRewriter &rewriter, Type srcType,
1009       memref::ReinterpretCastOp castOp,
1010       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1011     MemRefType targetMemRefType =
1012         castOp.getResult().getType().cast<MemRefType>();
1013     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1014                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1015     if (!llvmTargetDescriptorTy)
1016       return failure();
1017 
1018     // Create descriptor.
1019     Location loc = castOp.getLoc();
1020     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1021 
1022     // Set allocated and aligned pointers.
1023     Value allocatedPtr, alignedPtr;
1024     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1025                              castOp.source(), adaptor.source(), &allocatedPtr,
1026                              &alignedPtr);
1027     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1028     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1029 
1030     // Set offset.
1031     if (castOp.isDynamicOffset(0))
1032       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1033     else
1034       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1035 
1036     // Set sizes and strides.
1037     unsigned dynSizeId = 0;
1038     unsigned dynStrideId = 0;
1039     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1040       if (castOp.isDynamicSize(i))
1041         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1042       else
1043         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1044 
1045       if (castOp.isDynamicStride(i))
1046         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1047       else
1048         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1049     }
1050     *descriptor = desc;
1051     return success();
1052   }
1053 };
1054 
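/// Lowers `memref.reshape` for the case where the shape operand has a dynamic
/// length; the statically-shaped case is rejected here and expected to be
/// handled via `memref.reinterpret_cast`. The result is an unranked descriptor
/// whose sizes and strides are filled by an emitted loop running from
/// rank - 1 down to 0.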
1055 struct MemRefReshapeOpLowering
1056     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1057   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1058 
1059   LogicalResult
1060   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1061                   ConversionPatternRewriter &rewriter) const override {
1062     Type srcType = reshapeOp.source().getType();
1063 
1064     Value descriptor;
1065     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1066                                                adaptor, &descriptor)))
1067       return failure();
1068     rewriter.replaceOp(reshapeOp, {descriptor});
1069     return success();
1070   }
1071 
1072 private:
1073   LogicalResult
1074   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1075                                   Type srcType, memref::ReshapeOp reshapeOp,
1076                                   memref::ReshapeOp::Adaptor adaptor,
1077                                   Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref.reinterpret_cast`.
1080     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1081     if (shapeMemRefType.hasStaticShape())
1082       return failure();
1083 
    // The shape is a rank-1 memref with unknown length.
1085     Location loc = reshapeOp.getLoc();
1086     MemRefDescriptor shapeDesc(adaptor.shape());
1087     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1088 
1089     // Extract address space and element type.
1090     auto targetType =
1091         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1092     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1093     Type elementType = targetType.getElementType();
1094 
1095     // Create the unranked memref descriptor that holds the ranked one. The
1096     // inner descriptor is allocated on stack.
1097     auto targetDesc = UnrankedMemRefDescriptor::undef(
1098         rewriter, loc, typeConverter->convertType(targetType));
1099     targetDesc.setRank(rewriter, loc, resultRank);
1100     SmallVector<Value, 4> sizes;
1101     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1102                                            targetDesc, sizes);
1103     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1104         loc, getVoidPtrType(), sizes.front(), llvm::None);
1105     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1106 
1107     // Extract pointers and offset from the source memref.
1108     Value allocatedPtr, alignedPtr, offset;
1109     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1110                              reshapeOp.source(), adaptor.source(),
1111                              &allocatedPtr, &alignedPtr, &offset);
1112 
1113     // Set pointers and offset.
1114     Type llvmElementType = typeConverter->convertType(elementType);
1115     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1116         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1117     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1118                                               elementPtrPtrType, allocatedPtr);
1119     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1120                                             underlyingDescPtr,
1121                                             elementPtrPtrType, alignedPtr);
1122     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1123                                         underlyingDescPtr, elementPtrPtrType,
1124                                         offset);
1125 
1126     // Use the offset pointer as base for further addressing. Copy over the new
1127     // shape and compute strides. For this, we create a loop from rank-1 to 0.
1128     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1129         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1130         elementPtrPtrType);
1131     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1132         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1133     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1134     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1135     Value resultRankMinusOne =
1136         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1137 
1138     Block *initBlock = rewriter.getInsertionBlock();
1139     Type indexType = getTypeConverter()->getIndexType();
1140     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1141 
1142     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1143                                             {indexType, indexType}, {loc, loc});
1144 
1145     // Move the remaining initBlock ops to condBlock.
1146     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1147     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1148 
1149     rewriter.setInsertionPointToEnd(initBlock);
1150     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1151                                 condBlock);
1152     rewriter.setInsertionPointToStart(condBlock);
1153     Value indexArg = condBlock->getArgument(0);
1154     Value strideArg = condBlock->getArgument(1);
1155 
1156     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1157     Value pred = rewriter.create<LLVM::ICmpOp>(
1158         loc, IntegerType::get(rewriter.getContext(), 1),
1159         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1160 
1161     Block *bodyBlock =
1162         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1163     rewriter.setInsertionPointToStart(bodyBlock);
1164 
1165     // Copy size from shape to descriptor.
1166     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1167     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1168         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1169     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1170     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1171                                       targetSizesBase, indexArg, size);
1172 
1173     // Write stride value and compute next one.
1174     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1175                                         targetStridesBase, indexArg, strideArg);
1176     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1177 
1178     // Decrement loop counter and branch back.
1179     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1180     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1181                                 condBlock);
1182 
1183     Block *remainder =
1184         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1185 
1186     // Hook up the cond exit to the remainder.
1187     rewriter.setInsertionPointToEnd(condBlock);
1188     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1189                                     llvm::None);
1190 
1191     // Reset position to beginning of new remainder block.
1192     rewriter.setInsertionPointToStart(remainder);
1193 
1194     *descriptor = targetDesc;
1195     return success();
1196   }
1197 };
1198 
1199 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1200 /// `Value`s.
1201 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1202                                       Type &llvmIndexType,
1203                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1204   return llvm::to_vector<4>(
1205       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1206         if (auto attr = value.dyn_cast<Attribute>())
1207           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1208         return value.get<Value>();
1209       }));
1210 }
1211 
1212 /// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
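/// For example (illustrative), the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.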
1215 static DenseMap<int64_t, int64_t>
1216 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1217   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1218   for (auto &en : enumerate(reassociation)) {
1219     for (auto dim : en.value())
1220       expandedDimToCollapsedDim[dim] = en.index();
1221   }
1222   return expandedDimToCollapsedDim;
1223 }
1224 
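/// Compute the size of dimension `outDimIndex` of the expanded result, either
/// as a static attribute or as an LLVM value. At most one dimension per
/// reassociation group may be dynamic; its size is the corresponding source
/// dimension size divided by the product of the remaining static sizes in the
/// group. For example (a sketch, not taken from a test), expanding
/// memref<?xf32> into memref<?x4xf32> with reassociation [[0, 1]] gives
/// size(0) = size_in(0) / 4.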
1225 static OpFoldResult
1226 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1227                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1228                          MemRefDescriptor &inDesc,
1229                          ArrayRef<int64_t> inStaticShape,
1230                          ArrayRef<ReassociationIndices> reassocation,
1231                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1232   int64_t outDimSize = outStaticShape[outDimIndex];
1233   if (!ShapedType::isDynamic(outDimSize))
1234     return b.getIndexAttr(outDimSize);
1235 
1236   // Calculate the multiplication of all the out dim sizes except the
1237   // current dim.
1238   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1239   int64_t otherDimSizesMul = 1;
1240   for (auto otherDimIndex : reassocation[inDimIndex]) {
1241     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1242       continue;
1243     int64_t otherDimSize = outStaticShape[otherDimIndex];
1244     assert(!ShapedType::isDynamic(otherDimSize) &&
1245            "single dimension cannot be expanded into multiple dynamic "
1246            "dimensions");
1247     otherDimSizesMul *= otherDimSize;
1248   }
1249 
  // outDimSize = inDimSize / otherDimSizesMul
1251   int64_t inDimSize = inStaticShape[inDimIndex];
1252   Value inDimSizeDynamic =
1253       ShapedType::isDynamic(inDimSize)
1254           ? inDesc.size(b, loc, inDimIndex)
1255           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1256                                        b.getIndexAttr(inDimSize));
1257   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1258       loc, inDimSizeDynamic,
1259       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1260                                  b.getIndexAttr(otherDimSizesMul)));
1261   return outDimSizeDynamic;
1262 }
1263 
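/// Compute the size of dimension `outDimIndex` of the collapsed result as the
/// product of the sizes of all source dimensions folded into it. For example
/// (a sketch), collapsing memref<4x?xf32> with reassociation [[0, 1]] gives
/// size(0) = 4 * size_in(1).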
1264 static OpFoldResult getCollapsedOutputDimSize(
1265     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1266     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1267     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1268   if (!ShapedType::isDynamic(outDimSize))
1269     return b.getIndexAttr(outDimSize);
1270 
1271   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1272   Value outDimSizeDynamic = c1;
1273   for (auto inDimIndex : reassocation[outDimIndex]) {
1274     int64_t inDimSize = inStaticShape[inDimIndex];
1275     Value inDimSizeDynamic =
1276         ShapedType::isDynamic(inDimSize)
1277             ? inDesc.size(b, loc, inDimIndex)
1278             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1279                                          b.getIndexAttr(inDimSize));
1280     outDimSizeDynamic =
1281         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1282   }
1283   return outDimSizeDynamic;
1284 }
1285 
1286 static SmallVector<OpFoldResult, 4>
1287 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1288                         ArrayRef<ReassociationIndices> reassocation,
1289                         ArrayRef<int64_t> inStaticShape,
1290                         MemRefDescriptor &inDesc,
1291                         ArrayRef<int64_t> outStaticShape) {
1292   return llvm::to_vector<4>(llvm::map_range(
1293       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1294         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1295                                          outStaticShape[outDimIndex],
1296                                          inStaticShape, inDesc, reassocation);
1297       }));
1298 }
1299 
1300 static SmallVector<OpFoldResult, 4>
1301 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1302                        ArrayRef<ReassociationIndices> reassocation,
1303                        ArrayRef<int64_t> inStaticShape,
1304                        MemRefDescriptor &inDesc,
1305                        ArrayRef<int64_t> outStaticShape) {
1306   DenseMap<int64_t, int64_t> outDimToInDimMap =
1307       getExpandedDimToCollapsedDimMap(reassocation);
1308   return llvm::to_vector<4>(llvm::map_range(
1309       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1310         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1311                                         outStaticShape, inDesc, inStaticShape,
1312                                         reassocation, outDimToInDimMap);
1313       }));
1314 }
1315 
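/// Materialize the result shape of a collapse/expand reshape as LLVM index
/// values: the collapsed computation is used when the result rank is smaller
/// than the source rank, and the expanded one otherwise.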
1316 static SmallVector<Value>
1317 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1318                       ArrayRef<ReassociationIndices> reassocation,
1319                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1320                       ArrayRef<int64_t> outStaticShape) {
1321   return outStaticShape.size() < inStaticShape.size()
1322              ? getAsValues(b, loc, llvmIndexType,
1323                            getCollapsedOutputShape(b, loc, llvmIndexType,
1324                                                    reassocation, inStaticShape,
1325                                                    inDesc, outStaticShape))
1326              : getAsValues(b, loc, llvmIndexType,
1327                            getExpandedOutputShape(b, loc, llvmIndexType,
1328                                                   reassocation, inStaticShape,
1329                                                   inDesc, outStaticShape));
1330 }
1331 
1332 // ReshapeOp creates a new view descriptor of the proper rank.
// Dynamic sizes and strides are supported; when either the source or the
// target shape is dynamic, only identity layout maps are supported.
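// For illustration (not taken from a test), a collapse such as:
//   %0 = memref.collapse_shape %arg [[0, 1]]
//            : memref<4x?xf32> into memref<?xf32>
// reuses the source pointers and offset and sets the single result size to
// 4 * size(1) of the source, with a constant stride of 1.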
1335 template <typename ReshapeOp>
1336 class ReassociatingReshapeOpConversion
1337     : public ConvertOpToLLVMPattern<ReshapeOp> {
1338 public:
1339   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1340   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1341 
1342   LogicalResult
1343   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1344                   ConversionPatternRewriter &rewriter) const override {
1345     MemRefType dstType = reshapeOp.getResultType();
1346     MemRefType srcType = reshapeOp.getSrcType();
1347 
1348     // The condition on the layouts can be ignored when all shapes are static.
1349     if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
1350       if (!srcType.getLayout().isIdentity() ||
1351           !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only identity layout map is supported");
1354       }
1355     }
1356 
1357     int64_t offset;
1358     SmallVector<int64_t, 4> strides;
1359     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1360       return rewriter.notifyMatchFailure(
1361           reshapeOp, "failed to get stride and offset exprs");
1362     }
1363 
1364     MemRefDescriptor srcDesc(adaptor.src());
1365     Location loc = reshapeOp->getLoc();
1366     auto dstDesc = MemRefDescriptor::undef(
1367         rewriter, loc, this->typeConverter->convertType(dstType));
1368     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1369     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1370     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1371 
1372     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1373     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1374     Type llvmIndexType =
1375         this->typeConverter->convertType(rewriter.getIndexType());
1376     SmallVector<Value> dstShape = getDynamicOutputShape(
1377         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1378         srcStaticShape, srcDesc, dstStaticShape);
1379     for (auto &en : llvm::enumerate(dstShape))
1380       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1381 
1382     auto isStaticStride = [](int64_t stride) {
1383       return !ShapedType::isDynamicStrideOrOffset(stride);
1384     };
1385     if (llvm::all_of(strides, isStaticStride)) {
1386       for (auto &en : llvm::enumerate(strides))
1387         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1388     } else {
1389       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1390                                                    rewriter.getIndexAttr(1));
1391       Value stride = c1;
1392       for (auto dimIndex :
1393            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1394         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1395         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1396       }
1397     }
1398     rewriter.replaceOp(reshapeOp, {dstDesc});
1399     return success();
1400   }
1401 };
1402 
1403 /// Conversion pattern that transforms a subview op into:
1404 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1405 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1406 ///      and stride.
1407 /// The subview op is replaced by the descriptor.
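/// As an illustration (not the verbatim output), for a source with offset
/// `src_off` and strides `src_stride[i]`, and subview offsets `o[i]`, sizes
/// `s[i]` and strides `t[i]`, the resulting descriptor holds roughly:
///   offset    = src_off + sum_i(o[i] * src_stride[i])
///   size[j]   = s[i]
///   stride[j] = t[i] * src_stride[i]
/// where rank-reduced dimensions are dropped when mapping source index i to
/// result index j.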
1408 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1409   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1410 
1411   LogicalResult
1412   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1413                   ConversionPatternRewriter &rewriter) const override {
1414     auto loc = subViewOp.getLoc();
1415 
1416     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1417     auto sourceElementTy =
1418         typeConverter->convertType(sourceMemRefType.getElementType());
1419 
1420     auto viewMemRefType = subViewOp.getType();
1421     auto inferredType = memref::SubViewOp::inferResultType(
1422                             subViewOp.getSourceType(),
1423                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1424                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1425                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1426                             .cast<MemRefType>();
1427     auto targetElementTy =
1428         typeConverter->convertType(viewMemRefType.getElementType());
1429     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1430     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1431         !LLVM::isCompatibleType(sourceElementTy) ||
1432         !LLVM::isCompatibleType(targetElementTy) ||
1433         !LLVM::isCompatibleType(targetDescTy))
1434       return failure();
1435 
1436     // Extract the offset and strides from the type.
1437     int64_t offset;
1438     SmallVector<int64_t, 4> strides;
1439     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1440     if (failed(successStrides))
1441       return failure();
1442 
1443     // Create the descriptor.
1444     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1445       return failure();
1446     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1447     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1448 
1449     // Copy the buffer pointer from the old descriptor to the new one.
1450     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1451     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1452         loc,
1453         LLVM::LLVMPointerType::get(targetElementTy,
1454                                    viewMemRefType.getMemorySpaceAsInt()),
1455         extracted);
1456     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1457 
1458     // Copy the aligned pointer from the old descriptor to the new one.
1459     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1460     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1461         loc,
1462         LLVM::LLVMPointerType::get(targetElementTy,
1463                                    viewMemRefType.getMemorySpaceAsInt()),
1464         extracted);
1465     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1466 
1467     size_t inferredShapeRank = inferredType.getRank();
1468     size_t resultShapeRank = viewMemRefType.getRank();
1469 
1470     // Extract strides needed to compute offset.
1471     SmallVector<Value, 4> strideValues;
1472     strideValues.reserve(inferredShapeRank);
1473     for (unsigned i = 0; i < inferredShapeRank; ++i)
1474       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1475 
1476     // Offset.
1477     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1478     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1479       targetMemRef.setConstantOffset(rewriter, loc, offset);
1480     } else {
1481       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1482       // `inferredShapeRank` may be larger than the number of offset operands
1483       // because of trailing semantics. In this case, the offset is guaranteed
1484       // to be interpreted as 0 and we can just skip the extra dimensions.
1485       for (unsigned i = 0, e = std::min(inferredShapeRank,
1486                                         subViewOp.getMixedOffsets().size());
1487            i < e; ++i) {
1488         Value offset =
1489             // TODO: need OpFoldResult ODS adaptor to clean this up.
1490             subViewOp.isDynamicOffset(i)
1491                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1492                 : rewriter.create<LLVM::ConstantOp>(
1493                       loc, llvmIndexType,
1494                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1495         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1496         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1497       }
1498       targetMemRef.setOffset(rewriter, loc, baseOffset);
1499     }
1500 
1501     // Update sizes and strides.
1502     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1503     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1504     assert(mixedSizes.size() == mixedStrides.size() &&
1505            "expected sizes and strides of equal length");
1506     llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
1507     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1508          i >= 0 && j >= 0; --i) {
1509       if (unusedDims.contains(i))
1510         continue;
1511 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full source dimension size and the stride to be 1.
1515       Value size, stride;
1516       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1517         // If the static size is available, use it directly. This is similar to
1518         // the folding of dim(constant-op) but removes the need for dim to be
1519         // aware of LLVM constants and for this pass to be aware of std
1520         // constants.
1521         int64_t staticSize =
1522             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1523         if (staticSize != ShapedType::kDynamicSize) {
1524           size = rewriter.create<LLVM::ConstantOp>(
1525               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1526         } else {
1527           Value pos = rewriter.create<LLVM::ConstantOp>(
1528               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1529           Value dim =
1530               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1531           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1532               loc, llvmIndexType, dim);
1533           size = cast.getResult(0);
1534         }
1535         stride = rewriter.create<LLVM::ConstantOp>(
1536             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1537       } else {
1538         // TODO: need OpFoldResult ODS adaptor to clean this up.
1539         size =
1540             subViewOp.isDynamicSize(i)
1541                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1542                 : rewriter.create<LLVM::ConstantOp>(
1543                       loc, llvmIndexType,
1544                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1545         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1546           stride = rewriter.create<LLVM::ConstantOp>(
1547               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1548         } else {
1549           stride =
1550               subViewOp.isDynamicStride(i)
1551                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1552                   : rewriter.create<LLVM::ConstantOp>(
1553                         loc, llvmIndexType,
1554                         rewriter.getI64IntegerAttr(
1555                             subViewOp.getStaticStride(i)));
1556           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1557         }
1558       }
1559       targetMemRef.setSize(rewriter, loc, j, size);
1560       targetMemRef.setStride(rewriter, loc, j, stride);
1561       j--;
1562     }
1563 
1564     rewriter.replaceOp(subViewOp, {targetMemRef});
1565     return success();
1566   }
1567 };
1568 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride, where sizes and strides are permutations of the original
///      values.
/// The transpose op is replaced by the descriptor.
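/// For example (illustrative), with the permutation (d0, d1) -> (d1, d0), the
/// size and stride recorded for source dimension 0 are written to target
/// dimension 1, and vice versa.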
1576 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1577 public:
1578   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1579 
1580   LogicalResult
1581   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1582                   ConversionPatternRewriter &rewriter) const override {
1583     auto loc = transposeOp.getLoc();
1584     MemRefDescriptor viewMemRef(adaptor.in());
1585 
1586     // No permutation, early exit.
1587     if (transposeOp.permutation().isIdentity())
1588       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1589 
1590     auto targetMemRef = MemRefDescriptor::undef(
1591         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1592 
1593     // Copy the base and aligned pointers from the old descriptor to the new
1594     // one.
1595     targetMemRef.setAllocatedPtr(rewriter, loc,
1596                                  viewMemRef.allocatedPtr(rewriter, loc));
1597     targetMemRef.setAlignedPtr(rewriter, loc,
1598                                viewMemRef.alignedPtr(rewriter, loc));
1599 
    // Copy the offset from the old descriptor to the new one.
1601     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1602 
1603     // Iterate over the dimensions and apply size/stride permutation.
1604     for (const auto &en :
1605          llvm::enumerate(transposeOp.permutation().getResults())) {
1606       int sourcePos = en.index();
1607       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1608       targetMemRef.setSize(rewriter, loc, targetPos,
1609                            viewMemRef.size(rewriter, loc, sourcePos));
1610       targetMemRef.setStride(rewriter, loc, targetPos,
1611                              viewMemRef.stride(rewriter, loc, sourcePos));
1612     }
1613 
1614     rewriter.replaceOp(transposeOp, {targetMemRef});
1615     return success();
1616   }
1617 };
1618 
1619 /// Conversion pattern that transforms an op into:
1620 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1621 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1622 ///      and stride.
1623 /// The view op is replaced by the descriptor.
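/// For illustration (not the verbatim output), a view such as:
///   %1 = memref.view %src[%byte_shift][%size]
///            : memref<?xi8> to memref<?x4xf32>
/// reuses the allocated pointer, advances the aligned pointer by %byte_shift
/// bytes (a GEP on the i8 buffer) and bitcasts it to the f32 element type,
/// resets the offset to 0, and recomputes the sizes and strides.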
1624 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1625   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1626 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper operand
  // among the dynamic sizes.
1629   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1630                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1631                 unsigned idx) const {
1632     assert(idx < shape.size());
1633     if (!ShapedType::isDynamic(shape[idx]))
1634       return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in the range [0, idx).
1636     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1637       return ShapedType::isDynamic(v);
1638     });
1639     return dynamicSizes[nDynamic];
1640   }
1641 
1642   // Build and return the idx^th stride, either by returning the constant stride
1643   // or by computing the dynamic stride from the current `runningStride` and
1644   // `nextSize`. The caller should keep a running stride and update it with the
1645   // result returned by this function.
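  // For a contiguous memref this amounts to a running product of sizes from
  // the innermost dimension outwards; e.g. (illustrative) for memref<?x?xf32>
  // the computed strides are stride(1) = 1 and stride(0) = size(1).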
1646   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1647                   ArrayRef<int64_t> strides, Value nextSize,
1648                   Value runningStride, unsigned idx) const {
1649     assert(idx < strides.size());
1650     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1651       return createIndexConstant(rewriter, loc, strides[idx]);
1652     if (nextSize)
1653       return runningStride
1654                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1655                  : nextSize;
1656     assert(!runningStride);
1657     return createIndexConstant(rewriter, loc, 1);
1658   }
1659 
1660   LogicalResult
1661   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1662                   ConversionPatternRewriter &rewriter) const override {
1663     auto loc = viewOp.getLoc();
1664 
1665     auto viewMemRefType = viewOp.getType();
1666     auto targetElementTy =
1667         typeConverter->convertType(viewMemRefType.getElementType());
1668     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1669     if (!targetDescTy || !targetElementTy ||
1670         !LLVM::isCompatibleType(targetElementTy) ||
1671         !LLVM::isCompatibleType(targetDescTy))
1672       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1673              failure();
1674 
1675     int64_t offset;
1676     SmallVector<int64_t, 4> strides;
1677     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1678     if (failed(successStrides))
1679       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1680     assert(offset == 0 && "expected offset to be 0");
1681 
1682     // Create the descriptor.
1683     MemRefDescriptor sourceMemRef(adaptor.source());
1684     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1685 
1686     // Field 1: Copy the allocated pointer, used for malloc/free.
1687     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1688     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1689     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1690         loc,
1691         LLVM::LLVMPointerType::get(targetElementTy,
1692                                    srcMemRefType.getMemorySpaceAsInt()),
1693         allocatedPtr);
1694     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1695 
1696     // Field 2: Copy the actual aligned pointer to payload.
1697     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1698     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1699                                               alignedPtr, adaptor.byte_shift());
1700     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1701         loc,
1702         LLVM::LLVMPointerType::get(targetElementTy,
1703                                    srcMemRefType.getMemorySpaceAsInt()),
1704         alignedPtr);
1705     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1706 
1707     // Field 3: The offset in the resulting type must be 0. This is because of
1708     // the type change: an offset on srcType* may not be expressible as an
1709     // offset on dstType*.
1710     targetMemRef.setOffset(rewriter, loc,
1711                            createIndexConstant(rewriter, loc, offset));
1712 
1713     // Early exit for 0-D corner case.
1714     if (viewMemRefType.getRank() == 0)
1715       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1716 
1717     // Fields 4 and 5: Update sizes and strides.
1718     if (strides.back() != 1)
1719       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1720              failure();
1721     Value stride = nullptr, nextSize = nullptr;
1722     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1723       // Update size.
1724       Value size =
1725           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1726       targetMemRef.setSize(rewriter, loc, i, size);
1727       // Update stride.
1728       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1729       targetMemRef.setStride(rewriter, loc, i, stride);
1730       nextSize = size;
1731     }
1732 
1733     rewriter.replaceOp(viewOp, {targetMemRef});
1734     return success();
1735   }
1736 };
1737 
1738 //===----------------------------------------------------------------------===//
1739 // AtomicRMWOpLowering
1740 //===----------------------------------------------------------------------===//
1741 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
1744 static Optional<LLVM::AtomicBinOp>
1745 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1746   switch (atomicOp.kind()) {
1747   case arith::AtomicRMWKind::addf:
1748     return LLVM::AtomicBinOp::fadd;
1749   case arith::AtomicRMWKind::addi:
1750     return LLVM::AtomicBinOp::add;
1751   case arith::AtomicRMWKind::assign:
1752     return LLVM::AtomicBinOp::xchg;
1753   case arith::AtomicRMWKind::maxs:
1754     return LLVM::AtomicBinOp::max;
1755   case arith::AtomicRMWKind::maxu:
1756     return LLVM::AtomicBinOp::umax;
1757   case arith::AtomicRMWKind::mins:
1758     return LLVM::AtomicBinOp::min;
1759   case arith::AtomicRMWKind::minu:
1760     return LLVM::AtomicBinOp::umin;
1761   case arith::AtomicRMWKind::ori:
1762     return LLVM::AtomicBinOp::_or;
1763   case arith::AtomicRMWKind::andi:
1764     return LLVM::AtomicBinOp::_and;
1765   default:
1766     return llvm::None;
1767   }
1768   llvm_unreachable("Invalid AtomicRMWKind");
1769 }
1770 
1771 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1772   using Base::Base;
1773 
1774   LogicalResult
1775   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1776                   ConversionPatternRewriter &rewriter) const override {
1777     if (failed(match(atomicOp)))
1778       return failure();
1779     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1780     if (!maybeKind)
1781       return failure();
1782     auto resultType = adaptor.value().getType();
1783     auto memRefType = atomicOp.getMemRefType();
1784     auto dataPtr =
1785         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1786                              adaptor.indices(), rewriter);
1787     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1788         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1789         LLVM::AtomicOrdering::acq_rel);
1790     return success();
1791   }
1792 };
1793 
1794 } // namespace
1795 
1796 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1797                                                   RewritePatternSet &patterns) {
1798   // clang-format off
1799   patterns.add<
1800       AllocaOpLowering,
1801       AllocaScopeOpLowering,
1802       AtomicRMWOpLowering,
1803       AssumeAlignmentOpLowering,
1804       DimOpLowering,
1805       GenericAtomicRMWOpLowering,
1806       GlobalMemrefOpLowering,
1807       GetGlobalMemrefOpLowering,
1808       LoadOpLowering,
1809       MemRefCastOpLowering,
1810       MemRefCopyOpLowering,
1811       MemRefReinterpretCastOpLowering,
1812       MemRefReshapeOpLowering,
1813       PrefetchOpLowering,
1814       RankOpLowering,
1815       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1816       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1817       StoreOpLowering,
1818       SubViewOpLowering,
1819       TransposeOpLowering,
1820       ViewOpLowering>(converter);
1821   // clang-format on
1822   auto allocLowering = converter.getOptions().allocLowering;
1823   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1824     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1825   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1826     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1827 }
1828 
1829 namespace {
1830 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1831   MemRefToLLVMPass() = default;
1832 
1833   void runOnOperation() override {
1834     Operation *op = getOperation();
1835     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1836     LowerToLLVMOptions options(&getContext(),
1837                                dataLayoutAnalysis.getAtOrAbove(op));
1838     options.allocLowering =
1839         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1840                          : LowerToLLVMOptions::AllocLowering::Malloc);
1841     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1842       options.overrideIndexBitwidth(indexBitwidth);
1843 
1844     LLVMTypeConverter typeConverter(&getContext(), options,
1845                                     &dataLayoutAnalysis);
1846     RewritePatternSet patterns(&getContext());
1847     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1848     LLVMConversionTarget target(getContext());
1849     target.addLegalOp<FuncOp>();
1850     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1851       signalPassFailure();
1852   }
1853 };
1854 } // namespace
1855 
1856 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1857   return std::make_unique<MemRefToLLVMPass>();
1858 }
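
// Typical usage (a sketch, not part of this file): the pass can be added to a
// pass manager with `pm.addPass(createMemRefToLLVMPass())`, and the patterns
// can be reused in a custom lowering pipeline through
// `populateMemRefToLLVMConversionPatterns`, as `runOnOperation` above does.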
1859