1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/BlockAndValueMapping.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
29   AllocOpLowering(LLVMTypeConverter &converter)
30       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
31                                 converter) {}
32 
33   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
34                                           Location loc, Value sizeBytes,
35                                           Operation *op) const override {
36     // Heap allocations.
37     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
38     MemRefType memRefType = allocOp.getType();
39 
40     Value alignment;
41     if (auto alignmentAttr = allocOp.alignment()) {
42       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
43     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
44       // In the case where no alignment is specified, we may want to override
45       // `malloc's` behavior. `malloc` typically aligns at the size of the
46       // biggest scalar on a target HW. For non-scalars, use the natural
47       // alignment of the LLVM type given by the LLVM DataLayout.
48       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
49     }
50 
51     if (alignment) {
52       // Adjust the allocation size to consider alignment.
53       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
54     }
55 
56     // Allocate the underlying buffer and store a pointer to it in the MemRef
57     // descriptor.
58     Type elementPtrType = this->getElementPtrType(memRefType);
59     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
60         allocOp->getParentOfType<ModuleOp>(), getIndexType());
61     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
62                                   getVoidPtrType());
63     Value allocatedPtr =
64         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
65 
66     Value alignedPtr = allocatedPtr;
67     if (alignment) {
68       // Compute the aligned type pointer.
69       Value allocatedInt =
70           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
71       Value alignmentInt =
72           createAligned(rewriter, loc, allocatedInt, alignment);
73       alignedPtr =
74           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
75     }
76 
77     return std::make_tuple(allocatedPtr, alignedPtr);
78   }
79 };
80 
81 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
82   AlignedAllocOpLowering(LLVMTypeConverter &converter)
83       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
84                                 converter) {}
85 
86   /// Returns the memref's element size in bytes using the data layout active at
87   /// `op`.
88   // TODO: there are other places where this is used. Expose publicly?
89   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
90     const DataLayout *layout = &defaultLayout;
91     if (const DataLayoutAnalysis *analysis =
92             getTypeConverter()->getDataLayoutAnalysis()) {
93       layout = &analysis->getAbove(op);
94     }
95     Type elementType = memRefType.getElementType();
96     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
97       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
98                                                          *layout);
99     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
100       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
101           memRefElementType, *layout);
102     return layout->getTypeSize(elementType);
103   }
104 
105   /// Returns true if the memref size in bytes is known to be a multiple of
106   /// factor assuming the data layout active at `op`.
107   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
108                               Operation *op) const {
109     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
110     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
111       if (ShapedType::isDynamic(type.getDimSize(i)))
112         continue;
113       sizeDivisor = sizeDivisor * type.getDimSize(i);
114     }
115     return sizeDivisor % factor == 0;
116   }
117 
118   /// Returns the alignment to be used for the allocation call itself.
119   /// aligned_alloc requires the allocation size to be a power of two, and the
120   /// allocation size to be a multiple of alignment,
121   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
122     if (Optional<uint64_t> alignment = allocOp.alignment())
123       return *alignment;
124 
125     // Whenever we don't have alignment set, we will use an alignment
126     // consistent with the element type; since the allocation size has to be a
127     // power of two, we will bump to the next power of two if it already isn't.
128     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
129     return std::max(kMinAlignedAllocAlignment,
130                     llvm::PowerOf2Ceil(eltSizeBytes));
131   }
132 
133   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
134                                           Location loc, Value sizeBytes,
135                                           Operation *op) const override {
136     // Heap allocations.
137     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
138     MemRefType memRefType = allocOp.getType();
139     int64_t alignment = getAllocationAlignment(allocOp);
140     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
141 
142     // aligned_alloc requires size to be a multiple of alignment; we will pad
143     // the size to the next multiple if necessary.
144     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
145       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
146 
147     Type elementPtrType = this->getElementPtrType(memRefType);
148     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
149         allocOp->getParentOfType<ModuleOp>(), getIndexType());
150     auto results =
151         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
152                        getVoidPtrType());
153     Value allocatedPtr =
154         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
155 
156     return std::make_tuple(allocatedPtr, allocatedPtr);
157   }
158 
159   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
160   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
161 
162   /// Default layout to use in absence of the corresponding analysis.
163   DataLayout defaultLayout;
164 };
165 
166 // Out of line definition, required till C++17.
167 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
168 
169 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
170   AllocaOpLowering(LLVMTypeConverter &converter)
171       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
172                                 converter) {}
173 
174   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
175   /// is set to null for stack allocations. `accessAlignment` is set if
176   /// alignment is needed post allocation (for eg. in conjunction with malloc).
177   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
178                                           Location loc, Value sizeBytes,
179                                           Operation *op) const override {
180 
181     // With alloca, one gets a pointer to the element type right away.
182     // For stack allocations.
183     auto allocaOp = cast<memref::AllocaOp>(op);
184     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
185 
186     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
187         loc, elementPtrType, sizeBytes,
188         allocaOp.alignment() ? *allocaOp.alignment() : 0);
189 
190     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
191   }
192 };
193 
194 struct AllocaScopeOpLowering
195     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
196   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
197 
198   LogicalResult
199   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
200                   ConversionPatternRewriter &rewriter) const override {
201     OpBuilder::InsertionGuard guard(rewriter);
202     Location loc = allocaScopeOp.getLoc();
203 
204     // Split the current block before the AllocaScopeOp to create the inlining
205     // point.
206     auto *currentBlock = rewriter.getInsertionBlock();
207     auto *remainingOpsBlock =
208         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
209     Block *continueBlock;
210     if (allocaScopeOp.getNumResults() == 0) {
211       continueBlock = remainingOpsBlock;
212     } else {
213       continueBlock = rewriter.createBlock(
214           remainingOpsBlock, allocaScopeOp.getResultTypes(),
215           SmallVector<Location>(allocaScopeOp->getNumResults(),
216                                 allocaScopeOp.getLoc()));
217       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
218     }
219 
220     // Inline body region.
221     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
222     Block *afterBody = &allocaScopeOp.bodyRegion().back();
223     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
224 
225     // Save stack and then branch into the body of the region.
226     rewriter.setInsertionPointToEnd(currentBlock);
227     auto stackSaveOp =
228         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
229     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
230 
231     // Replace the alloca_scope return with a branch that jumps out of the body.
232     // Stack restore before leaving the body region.
233     rewriter.setInsertionPointToEnd(afterBody);
234     auto returnOp =
235         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
236     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
237         returnOp, returnOp.results(), continueBlock);
238 
239     // Insert stack restore before jumping out the body of the region.
240     rewriter.setInsertionPoint(branchOp);
241     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
242 
243     // Replace the op with values return from the body region.
244     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
245 
246     return success();
247   }
248 };
249 
250 struct AssumeAlignmentOpLowering
251     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
252   using ConvertOpToLLVMPattern<
253       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
254 
255   LogicalResult
256   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
257                   ConversionPatternRewriter &rewriter) const override {
258     Value memref = adaptor.memref();
259     unsigned alignment = op.alignment();
260     auto loc = op.getLoc();
261 
262     MemRefDescriptor memRefDescriptor(memref);
263     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
264 
265     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
266     // the asserted memref.alignedPtr isn't used anywhere else, as the real
267     // users like load/store/views always re-extract memref.alignedPtr as they
268     // get lowered.
269     //
270     // This relies on LLVM's CSE optimization (potentially after SROA), since
271     // after CSE all memref.alignedPtr instances get de-duplicated into the same
272     // pointer SSA value.
273     auto intPtrType =
274         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
275     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
276     Value mask =
277         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
278     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
279     rewriter.create<LLVM::AssumeOp>(
280         loc, rewriter.create<LLVM::ICmpOp>(
281                  loc, LLVM::ICmpPredicate::eq,
282                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
283 
284     rewriter.eraseOp(op);
285     return success();
286   }
287 };
288 
289 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
290 // The memref descriptor being an SSA value, there is no need to clean it up
291 // in any way.
292 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
293   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
294 
295   explicit DeallocOpLowering(LLVMTypeConverter &converter)
296       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
297 
298   LogicalResult
299   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
300                   ConversionPatternRewriter &rewriter) const override {
301     // Insert the `free` declaration if it is not already present.
302     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
303     MemRefDescriptor memref(adaptor.memref());
304     Value casted = rewriter.create<LLVM::BitcastOp>(
305         op.getLoc(), getVoidPtrType(),
306         memref.allocatedPtr(rewriter, op.getLoc()));
307     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
308         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
309     return success();
310   }
311 };
312 
313 // A `dim` is converted to a constant for static sizes and to an access to the
314 // size stored in the memref descriptor for dynamic sizes.
315 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
316   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
317 
318   LogicalResult
319   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
320                   ConversionPatternRewriter &rewriter) const override {
321     Type operandType = dimOp.source().getType();
322     if (operandType.isa<UnrankedMemRefType>()) {
323       rewriter.replaceOp(
324           dimOp, {extractSizeOfUnrankedMemRef(
325                      operandType, dimOp, adaptor.getOperands(), rewriter)});
326 
327       return success();
328     }
329     if (operandType.isa<MemRefType>()) {
330       rewriter.replaceOp(
331           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
332                                             adaptor.getOperands(), rewriter)});
333       return success();
334     }
335     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
336   }
337 
338 private:
339   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
340                                     OpAdaptor adaptor,
341                                     ConversionPatternRewriter &rewriter) const {
342     Location loc = dimOp.getLoc();
343 
344     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
345     auto scalarMemRefType =
346         MemRefType::get({}, unrankedMemRefType.getElementType());
347     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
348 
349     // Extract pointer to the underlying ranked descriptor and bitcast it to a
350     // memref<element_type> descriptor pointer to minimize the number of GEP
351     // operations.
352     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
353     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
354     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
355         loc,
356         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
357                                    addressSpace),
358         underlyingRankedDesc);
359 
360     // Get pointer to offset field of memref<element_type> descriptor.
361     Type indexPtrTy = LLVM::LLVMPointerType::get(
362         getTypeConverter()->getIndexType(), addressSpace);
363     Value two = rewriter.create<LLVM::ConstantOp>(
364         loc, typeConverter->convertType(rewriter.getI32Type()),
365         rewriter.getI32IntegerAttr(2));
366     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
367         loc, indexPtrTy, scalarMemRefDescPtr,
368         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
369 
370     // The size value that we have to extract can be obtained using GEPop with
371     // `dimOp.index() + 1` index argument.
372     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
373         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
374     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
375                                                  ValueRange({idxPlusOne}));
376     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
377   }
378 
379   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
380     if (Optional<int64_t> idx = dimOp.getConstantIndex())
381       return idx;
382 
383     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
384       return constantOp.getValue()
385           .cast<IntegerAttr>()
386           .getValue()
387           .getSExtValue();
388 
389     return llvm::None;
390   }
391 
392   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
393                                   OpAdaptor adaptor,
394                                   ConversionPatternRewriter &rewriter) const {
395     Location loc = dimOp.getLoc();
396 
397     // Take advantage if index is constant.
398     MemRefType memRefType = operandType.cast<MemRefType>();
399     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
400       int64_t i = index.getValue();
401       if (memRefType.isDynamicDim(i)) {
402         // extract dynamic size from the memref descriptor.
403         MemRefDescriptor descriptor(adaptor.source());
404         return descriptor.size(rewriter, loc, i);
405       }
406       // Use constant for static size.
407       int64_t dimSize = memRefType.getDimSize(i);
408       return createIndexConstant(rewriter, loc, dimSize);
409     }
410     Value index = adaptor.index();
411     int64_t rank = memRefType.getRank();
412     MemRefDescriptor memrefDescriptor(adaptor.source());
413     return memrefDescriptor.size(rewriter, loc, index, rank);
414   }
415 };
416 
417 /// Common base for load and store operations on MemRefs. Restricts the match
418 /// to supported MemRef types. Provides functionality to emit code accessing a
419 /// specific element of the underlying data buffer.
420 template <typename Derived>
421 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
422   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
423   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
424   using Base = LoadStoreOpLowering<Derived>;
425 
426   LogicalResult match(Derived op) const override {
427     MemRefType type = op.getMemRefType();
428     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
429   }
430 };
431 
432 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
433 /// retried until it succeeds in atomically storing a new value into memory.
434 ///
435 ///      +---------------------------------+
436 ///      |   <code before the AtomicRMWOp> |
437 ///      |   <compute initial %loaded>     |
438 ///      |   cf.br loop(%loaded)              |
439 ///      +---------------------------------+
440 ///             |
441 ///  -------|   |
442 ///  |      v   v
443 ///  |   +--------------------------------+
444 ///  |   | loop(%loaded):                 |
445 ///  |   |   <body contents>              |
446 ///  |   |   %pair = cmpxchg              |
447 ///  |   |   %ok = %pair[0]               |
448 ///  |   |   %new = %pair[1]              |
449 ///  |   |   cf.cond_br %ok, end, loop(%new) |
450 ///  |   +--------------------------------+
451 ///  |          |        |
452 ///  |-----------        |
453 ///                      v
454 ///      +--------------------------------+
455 ///      | end:                           |
456 ///      |   <code after the AtomicRMWOp> |
457 ///      +--------------------------------+
458 ///
459 struct GenericAtomicRMWOpLowering
460     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
461   using Base::Base;
462 
463   LogicalResult
464   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
465                   ConversionPatternRewriter &rewriter) const override {
466     auto loc = atomicOp.getLoc();
467     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
468 
469     // Split the block into initial, loop, and ending parts.
470     auto *initBlock = rewriter.getInsertionBlock();
471     auto *loopBlock = rewriter.createBlock(
472         initBlock->getParent(), std::next(Region::iterator(initBlock)),
473         valueType, loc);
474     auto *endBlock = rewriter.createBlock(
475         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
476 
477     // Operations range to be moved to `endBlock`.
478     auto opsToMoveStart = atomicOp->getIterator();
479     auto opsToMoveEnd = initBlock->back().getIterator();
480 
481     // Compute the loaded value and branch to the loop block.
482     rewriter.setInsertionPointToEnd(initBlock);
483     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
484     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
485                                         adaptor.indices(), rewriter);
486     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
487     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
488 
489     // Prepare the body of the loop block.
490     rewriter.setInsertionPointToStart(loopBlock);
491 
492     // Clone the GenericAtomicRMWOp region and extract the result.
493     auto loopArgument = loopBlock->getArgument(0);
494     BlockAndValueMapping mapping;
495     mapping.map(atomicOp.getCurrentValue(), loopArgument);
496     Block &entryBlock = atomicOp.body().front();
497     for (auto &nestedOp : entryBlock.without_terminator()) {
498       Operation *clone = rewriter.clone(nestedOp, mapping);
499       mapping.map(nestedOp.getResults(), clone->getResults());
500     }
501     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
502 
503     // Prepare the epilog of the loop block.
504     // Append the cmpxchg op to the end of the loop block.
505     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
506     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
507     auto boolType = IntegerType::get(rewriter.getContext(), 1);
508     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
509                                                      {valueType, boolType});
510     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
511         loc, pairType, dataPtr, loopArgument, result, successOrdering,
512         failureOrdering);
513     // Extract the %new_loaded and %ok values from the pair.
514     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
515         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
516     Value ok = rewriter.create<LLVM::ExtractValueOp>(
517         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
518 
519     // Conditionally branch to the end or back to the loop depending on %ok.
520     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
521                                     loopBlock, newLoaded);
522 
523     rewriter.setInsertionPointToEnd(endBlock);
524     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
525                  std::next(opsToMoveEnd), rewriter);
526 
527     // The 'result' of the atomic_rmw op is the newly loaded value.
528     rewriter.replaceOp(atomicOp, {newLoaded});
529 
530     return success();
531   }
532 
533 private:
534   // Clones a segment of ops [start, end) and erases the original.
535   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
536                     Block::iterator start, Block::iterator end,
537                     ConversionPatternRewriter &rewriter) const {
538     BlockAndValueMapping mapping;
539     mapping.map(oldResult, newResult);
540     SmallVector<Operation *, 2> opsToErase;
541     for (auto it = start; it != end; ++it) {
542       rewriter.clone(*it, mapping);
543       opsToErase.push_back(&*it);
544     }
545     for (auto *it : opsToErase)
546       rewriter.eraseOp(it);
547   }
548 };
549 
550 /// Returns the LLVM type of the global variable given the memref type `type`.
551 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
552                                           LLVMTypeConverter &typeConverter) {
553   // LLVM type for a global memref will be a multi-dimension array. For
554   // declarations or uninitialized global memrefs, we can potentially flatten
555   // this to a 1D array. However, for memref.global's with an initial value,
556   // we do not intend to flatten the ElementsAttribute when going from std ->
557   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
558   Type elementType = typeConverter.convertType(type.getElementType());
559   Type arrayTy = elementType;
560   // Shape has the outermost dim at index 0, so need to walk it backwards
561   for (int64_t dim : llvm::reverse(type.getShape()))
562     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
563   return arrayTy;
564 }
565 
566 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
567 struct GlobalMemrefOpLowering
568     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
569   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
570 
571   LogicalResult
572   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
573                   ConversionPatternRewriter &rewriter) const override {
574     MemRefType type = global.type();
575     if (!isConvertibleAndHasIdentityMaps(type))
576       return failure();
577 
578     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
579 
580     LLVM::Linkage linkage =
581         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
582 
583     Attribute initialValue = nullptr;
584     if (!global.isExternal() && !global.isUninitialized()) {
585       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
586       initialValue = elementsAttr;
587 
588       // For scalar memrefs, the global variable created is of the element type,
589       // so unpack the elements attribute to extract the value.
590       if (type.getRank() == 0)
591         initialValue = elementsAttr.getSplatValue<Attribute>();
592     }
593 
594     uint64_t alignment = global.alignment().getValueOr(0);
595 
596     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
597         global, arrayTy, global.constant(), linkage, global.sym_name(),
598         initialValue, alignment, type.getMemorySpaceAsInt());
599     if (!global.isExternal() && global.isUninitialized()) {
600       Block *blk = new Block();
601       newGlobal.getInitializerRegion().push_back(blk);
602       rewriter.setInsertionPointToStart(blk);
603       Value undef[] = {
604           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
605       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
606     }
607     return success();
608   }
609 };
610 
611 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
612 /// the first element stashed into the descriptor. This reuses
613 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
614 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
615   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
616       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
617                                 converter) {}
618 
619   /// Buffer "allocation" for memref.get_global op is getting the address of
620   /// the global variable referenced.
621   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
622                                           Location loc, Value sizeBytes,
623                                           Operation *op) const override {
624     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
625     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
626     unsigned memSpace = type.getMemorySpaceAsInt();
627 
628     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
629     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
630         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
631 
632     // Get the address of the first element in the array by creating a GEP with
633     // the address of the GV as the base, and (rank + 1) number of 0 indices.
634     Type elementType = typeConverter->convertType(type.getElementType());
635     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
636 
637     SmallVector<Value> operands;
638     operands.insert(operands.end(), type.getRank() + 1,
639                     createIndexConstant(rewriter, loc, 0));
640     auto gep =
641         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
642 
643     // We do not expect the memref obtained using `memref.get_global` to be
644     // ever deallocated. Set the allocated pointer to be known bad value to
645     // help debug if that ever happens.
646     auto intPtrType = getIntPtrType(memSpace);
647     Value deadBeefConst =
648         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
649     auto deadBeefPtr =
650         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
651 
652     // Both allocated and aligned pointers are same. We could potentially stash
653     // a nullptr for the allocated pointer since we do not expect any dealloc.
654     return std::make_tuple(deadBeefPtr, gep);
655   }
656 };
657 
658 // Load operation is lowered to obtaining a pointer to the indexed element
659 // and loading it.
660 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
661   using Base::Base;
662 
663   LogicalResult
664   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
665                   ConversionPatternRewriter &rewriter) const override {
666     auto type = loadOp.getMemRefType();
667 
668     Value dataPtr = getStridedElementPtr(
669         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
670     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
671     return success();
672   }
673 };
674 
675 // Store operation is lowered to obtaining a pointer to the indexed element,
676 // and storing the given value to it.
677 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
678   using Base::Base;
679 
680   LogicalResult
681   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
682                   ConversionPatternRewriter &rewriter) const override {
683     auto type = op.getMemRefType();
684 
685     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
686                                          adaptor.indices(), rewriter);
687     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
688     return success();
689   }
690 };
691 
692 // The prefetch operation is lowered in a way similar to the load operation
693 // except that the llvm.prefetch operation is used for replacement.
694 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
695   using Base::Base;
696 
697   LogicalResult
698   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
699                   ConversionPatternRewriter &rewriter) const override {
700     auto type = prefetchOp.getMemRefType();
701     auto loc = prefetchOp.getLoc();
702 
703     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
704                                          adaptor.indices(), rewriter);
705 
706     // Replace with llvm.prefetch.
707     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
708     auto isWrite = rewriter.create<LLVM::ConstantOp>(
709         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
710     auto localityHint = rewriter.create<LLVM::ConstantOp>(
711         loc, llvmI32Type,
712         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
713     auto isData = rewriter.create<LLVM::ConstantOp>(
714         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
715 
716     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
717                                                 localityHint, isData);
718     return success();
719   }
720 };
721 
722 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
723   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
724 
725   LogicalResult
726   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
727                   ConversionPatternRewriter &rewriter) const override {
728     Location loc = op.getLoc();
729     Type operandType = op.memref().getType();
730     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
731       UnrankedMemRefDescriptor desc(adaptor.memref());
732       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
733       return success();
734     }
735     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
736       rewriter.replaceOp(
737           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
738       return success();
739     }
740     return failure();
741   }
742 };
743 
744 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
745   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
746 
747   LogicalResult match(memref::CastOp memRefCastOp) const override {
748     Type srcType = memRefCastOp.getOperand().getType();
749     Type dstType = memRefCastOp.getType();
750 
751     // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
752     // used for type erasure. For now they must preserve underlying element type
753     // and require source and result type to have the same rank. Therefore,
754     // perform a sanity check that the underlying structs are the same. Once op
755     // semantics are relaxed we can revisit.
756     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
757       return success(typeConverter->convertType(srcType) ==
758                      typeConverter->convertType(dstType));
759 
760     // At least one of the operands is unranked type
761     assert(srcType.isa<UnrankedMemRefType>() ||
762            dstType.isa<UnrankedMemRefType>());
763 
764     // Unranked to unranked cast is disallowed
765     return !(srcType.isa<UnrankedMemRefType>() &&
766              dstType.isa<UnrankedMemRefType>())
767                ? success()
768                : failure();
769   }
770 
771   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
772                ConversionPatternRewriter &rewriter) const override {
773     auto srcType = memRefCastOp.getOperand().getType();
774     auto dstType = memRefCastOp.getType();
775     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
776     auto loc = memRefCastOp.getLoc();
777 
778     // For ranked/ranked case, just keep the original descriptor.
779     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
780       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
781 
782     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
783       // Casting ranked to unranked memref type
784       // Set the rank in the destination from the memref type
785       // Allocate space on the stack and copy the src memref descriptor
786       // Set the ptr in the destination to the stack space
787       auto srcMemRefType = srcType.cast<MemRefType>();
788       int64_t rank = srcMemRefType.getRank();
789       // ptr = AllocaOp sizeof(MemRefDescriptor)
790       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
791           loc, adaptor.source(), rewriter);
792       // voidptr = BitCastOp srcType* to void*
793       auto voidPtr =
794           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
795               .getResult();
796       // rank = ConstantOp srcRank
797       auto rankVal = rewriter.create<LLVM::ConstantOp>(
798           loc, getIndexType(), rewriter.getIndexAttr(rank));
799       // undef = UndefOp
800       UnrankedMemRefDescriptor memRefDesc =
801           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
802       // d1 = InsertValueOp undef, rank, 0
803       memRefDesc.setRank(rewriter, loc, rankVal);
804       // d2 = InsertValueOp d1, voidptr, 1
805       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
806       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
807 
808     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
809       // Casting from unranked type to ranked.
810       // The operation is assumed to be doing a correct cast. If the destination
811       // type mismatches the unranked the type, it is undefined behavior.
812       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
813       // ptr = ExtractValueOp src, 1
814       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
815       // castPtr = BitCastOp i8* to structTy*
816       auto castPtr =
817           rewriter
818               .create<LLVM::BitcastOp>(
819                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
820               .getResult();
821       // struct = LoadOp castPtr
822       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
823       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
824     } else {
825       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
826     }
827   }
828 };
829 
830 /// Pattern to lower a `memref.copy` to llvm.
831 ///
832 /// For memrefs with identity layouts, the copy is lowered to the llvm
833 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
834 /// to the generic `MemrefCopyFn`.
835 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
836   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
837 
838   LogicalResult
839   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
840                           ConversionPatternRewriter &rewriter) const {
841     auto loc = op.getLoc();
842     auto srcType = op.source().getType().dyn_cast<MemRefType>();
843 
844     MemRefDescriptor srcDesc(adaptor.source());
845 
846     // Compute number of elements.
847     Value numElements = rewriter.create<LLVM::ConstantOp>(
848         loc, getIndexType(), rewriter.getIndexAttr(1));
849     for (int pos = 0; pos < srcType.getRank(); ++pos) {
850       auto size = srcDesc.size(rewriter, loc, pos);
851       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
852     }
853 
854     // Get element size.
855     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
856     // Compute total.
857     Value totalSize =
858         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
859 
860     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
861     Value srcOffset = srcDesc.offset(rewriter, loc);
862     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
863                                                 srcBasePtr, srcOffset);
864     MemRefDescriptor targetDesc(adaptor.target());
865     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
866     Value targetOffset = targetDesc.offset(rewriter, loc);
867     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
868                                                    targetBasePtr, targetOffset);
869     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
870         loc, typeConverter->convertType(rewriter.getI1Type()),
871         rewriter.getBoolAttr(false));
872     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
873                                     isVolatile);
874     rewriter.eraseOp(op);
875 
876     return success();
877   }
878 
879   LogicalResult
880   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
881                              ConversionPatternRewriter &rewriter) const {
882     auto loc = op.getLoc();
883     auto srcType = op.source().getType().cast<BaseMemRefType>();
884     auto targetType = op.target().getType().cast<BaseMemRefType>();
885 
886     // First make sure we have an unranked memref descriptor representation.
887     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
888       auto rank = rewriter.create<LLVM::ConstantOp>(
889           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
890       auto *typeConverter = getTypeConverter();
891       auto ptr =
892           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
893       auto voidPtr =
894           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
895               .getResult();
896       auto unrankedType =
897           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
898       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
899                                             unrankedType,
900                                             ValueRange{rank, voidPtr});
901     };
902 
903     Value unrankedSource = srcType.hasRank()
904                                ? makeUnranked(adaptor.source(), srcType)
905                                : adaptor.source();
906     Value unrankedTarget = targetType.hasRank()
907                                ? makeUnranked(adaptor.target(), targetType)
908                                : adaptor.target();
909 
910     // Now promote the unranked descriptors to the stack.
911     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
912                                                  rewriter.getIndexAttr(1));
913     auto promote = [&](Value desc) {
914       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
915       auto allocated =
916           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
917       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
918       return allocated;
919     };
920 
921     auto sourcePtr = promote(unrankedSource);
922     auto targetPtr = promote(unrankedTarget);
923 
924     unsigned typeSize =
925         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
926     auto elemSize = rewriter.create<LLVM::ConstantOp>(
927         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
928     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
929         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
930     rewriter.create<LLVM::CallOp>(loc, copyFn,
931                                   ValueRange{elemSize, sourcePtr, targetPtr});
932     rewriter.eraseOp(op);
933 
934     return success();
935   }
936 
937   LogicalResult
938   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
939                   ConversionPatternRewriter &rewriter) const override {
940     auto srcType = op.source().getType().cast<BaseMemRefType>();
941     auto targetType = op.target().getType().cast<BaseMemRefType>();
942 
943     auto isContiguousMemrefType = [](BaseMemRefType type) {
944       auto memrefType = type.dyn_cast<mlir::MemRefType>();
945       // We can use memcpy for memrefs if they have an identity layout or are
946       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
947       // special case handled by memrefCopy.
948       return memrefType &&
949              (memrefType.getLayout().isIdentity() ||
950               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
951                isStaticShapeAndContiguousRowMajor(memrefType)));
952     };
953 
954     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
955       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
956 
957     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
958   }
959 };
960 
961 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
962 /// memref type. In unranked case, the fields are extracted from the underlying
963 /// ranked descriptor.
964 static void extractPointersAndOffset(Location loc,
965                                      ConversionPatternRewriter &rewriter,
966                                      LLVMTypeConverter &typeConverter,
967                                      Value originalOperand,
968                                      Value convertedOperand,
969                                      Value *allocatedPtr, Value *alignedPtr,
970                                      Value *offset = nullptr) {
971   Type operandType = originalOperand.getType();
972   if (operandType.isa<MemRefType>()) {
973     MemRefDescriptor desc(convertedOperand);
974     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
975     *alignedPtr = desc.alignedPtr(rewriter, loc);
976     if (offset != nullptr)
977       *offset = desc.offset(rewriter, loc);
978     return;
979   }
980 
981   unsigned memorySpace =
982       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
983   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
984   Type llvmElementType = typeConverter.convertType(elementType);
985   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
986       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
987 
988   // Extract pointer to the underlying ranked memref descriptor and cast it to
989   // ElemType**.
990   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
991   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
992 
993   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
994       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
995   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
996       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
997   if (offset != nullptr) {
998     *offset = UnrankedMemRefDescriptor::offset(
999         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1000   }
1001 }
1002 
1003 struct MemRefReinterpretCastOpLowering
1004     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1005   using ConvertOpToLLVMPattern<
1006       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1007 
1008   LogicalResult
1009   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1010                   ConversionPatternRewriter &rewriter) const override {
1011     Type srcType = castOp.source().getType();
1012 
1013     Value descriptor;
1014     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1015                                                adaptor, &descriptor)))
1016       return failure();
1017     rewriter.replaceOp(castOp, {descriptor});
1018     return success();
1019   }
1020 
1021 private:
1022   LogicalResult convertSourceMemRefToDescriptor(
1023       ConversionPatternRewriter &rewriter, Type srcType,
1024       memref::ReinterpretCastOp castOp,
1025       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1026     MemRefType targetMemRefType =
1027         castOp.getResult().getType().cast<MemRefType>();
1028     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1029                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1030     if (!llvmTargetDescriptorTy)
1031       return failure();
1032 
1033     // Create descriptor.
1034     Location loc = castOp.getLoc();
1035     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1036 
1037     // Set allocated and aligned pointers.
1038     Value allocatedPtr, alignedPtr;
1039     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1040                              castOp.source(), adaptor.source(), &allocatedPtr,
1041                              &alignedPtr);
1042     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1043     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1044 
1045     // Set offset.
1046     if (castOp.isDynamicOffset(0))
1047       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1048     else
1049       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1050 
1051     // Set sizes and strides.
1052     unsigned dynSizeId = 0;
1053     unsigned dynStrideId = 0;
1054     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1055       if (castOp.isDynamicSize(i))
1056         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1057       else
1058         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1059 
1060       if (castOp.isDynamicStride(i))
1061         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1062       else
1063         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1064     }
1065     *descriptor = desc;
1066     return success();
1067   }
1068 };
1069 
1070 struct MemRefReshapeOpLowering
1071     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1072   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1073 
1074   LogicalResult
1075   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1076                   ConversionPatternRewriter &rewriter) const override {
1077     Type srcType = reshapeOp.source().getType();
1078 
1079     Value descriptor;
1080     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1081                                                adaptor, &descriptor)))
1082       return failure();
1083     rewriter.replaceOp(reshapeOp, {descriptor});
1084     return success();
1085   }
1086 
1087 private:
1088   LogicalResult
1089   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1090                                   Type srcType, memref::ReshapeOp reshapeOp,
1091                                   memref::ReshapeOp::Adaptor adaptor,
1092                                   Value *descriptor) const {
1093     // Conversion for statically-known shape args is performed via
1094     // `memref_reinterpret_cast`.
1095     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1096     if (shapeMemRefType.hasStaticShape())
1097       return failure();
1098 
1099     // The shape is a rank-1 tensor with unknown length.
1100     Location loc = reshapeOp.getLoc();
1101     MemRefDescriptor shapeDesc(adaptor.shape());
1102     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1103 
1104     // Extract address space and element type.
1105     auto targetType =
1106         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1107     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1108     Type elementType = targetType.getElementType();
1109 
1110     // Create the unranked memref descriptor that holds the ranked one. The
1111     // inner descriptor is allocated on stack.
1112     auto targetDesc = UnrankedMemRefDescriptor::undef(
1113         rewriter, loc, typeConverter->convertType(targetType));
1114     targetDesc.setRank(rewriter, loc, resultRank);
1115     SmallVector<Value, 4> sizes;
1116     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1117                                            targetDesc, sizes);
1118     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1119         loc, getVoidPtrType(), sizes.front(), llvm::None);
1120     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1121 
1122     // Extract pointers and offset from the source memref.
1123     Value allocatedPtr, alignedPtr, offset;
1124     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1125                              reshapeOp.source(), adaptor.source(),
1126                              &allocatedPtr, &alignedPtr, &offset);
1127 
1128     // Set pointers and offset.
1129     Type llvmElementType = typeConverter->convertType(elementType);
1130     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1131         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1132     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1133                                               elementPtrPtrType, allocatedPtr);
1134     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1135                                             underlyingDescPtr,
1136                                             elementPtrPtrType, alignedPtr);
1137     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1138                                         underlyingDescPtr, elementPtrPtrType,
1139                                         offset);
1140 
1141     // Use the offset pointer as base for further addressing. Copy over the new
1142     // shape and compute strides. For this, we create a loop from rank-1 to 0.
1143     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1144         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1145         elementPtrPtrType);
1146     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1147         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1148     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1149     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1150     Value resultRankMinusOne =
1151         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1152 
1153     Block *initBlock = rewriter.getInsertionBlock();
1154     Type indexType = getTypeConverter()->getIndexType();
1155     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1156 
1157     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1158                                             {indexType, indexType}, {loc, loc});
1159 
1160     // Move the remaining initBlock ops to condBlock.
1161     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1162     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1163 
1164     rewriter.setInsertionPointToEnd(initBlock);
1165     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1166                                 condBlock);
1167     rewriter.setInsertionPointToStart(condBlock);
1168     Value indexArg = condBlock->getArgument(0);
1169     Value strideArg = condBlock->getArgument(1);
1170 
1171     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1172     Value pred = rewriter.create<LLVM::ICmpOp>(
1173         loc, IntegerType::get(rewriter.getContext(), 1),
1174         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1175 
1176     Block *bodyBlock =
1177         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1178     rewriter.setInsertionPointToStart(bodyBlock);
1179 
1180     // Copy size from shape to descriptor.
1181     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1182     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1183         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1184     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1185     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1186                                       targetSizesBase, indexArg, size);
1187 
1188     // Write stride value and compute next one.
1189     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1190                                         targetStridesBase, indexArg, strideArg);
1191     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1192 
1193     // Decrement loop counter and branch back.
1194     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1195     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1196                                 condBlock);
1197 
1198     Block *remainder =
1199         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1200 
1201     // Hook up the cond exit to the remainder.
1202     rewriter.setInsertionPointToEnd(condBlock);
1203     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1204                                     llvm::None);
1205 
1206     // Reset position to beginning of new remainder block.
1207     rewriter.setInsertionPointToStart(remainder);
1208 
1209     *descriptor = targetDesc;
1210     return success();
1211   }
1212 };
1213 
1214 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1215 /// `Value`s.
1216 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1217                                       Type &llvmIndexType,
1218                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1219   return llvm::to_vector<4>(
1220       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1221         if (auto attr = value.dyn_cast<Attribute>())
1222           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1223         return value.get<Value>();
1224       }));
1225 }
1226 
1227 /// Compute a map that for a given dimension of the expanded type gives the
1228 /// dimension in the collapsed type it maps to. Essentially its the inverse of
1229 /// the `reassocation` maps.
1230 static DenseMap<int64_t, int64_t>
1231 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1232   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1233   for (auto &en : enumerate(reassociation)) {
1234     for (auto dim : en.value())
1235       expandedDimToCollapsedDim[dim] = en.index();
1236   }
1237   return expandedDimToCollapsedDim;
1238 }
1239 
1240 static OpFoldResult
1241 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1242                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1243                          MemRefDescriptor &inDesc,
1244                          ArrayRef<int64_t> inStaticShape,
1245                          ArrayRef<ReassociationIndices> reassocation,
1246                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1247   int64_t outDimSize = outStaticShape[outDimIndex];
1248   if (!ShapedType::isDynamic(outDimSize))
1249     return b.getIndexAttr(outDimSize);
1250 
1251   // Calculate the multiplication of all the out dim sizes except the
1252   // current dim.
1253   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1254   int64_t otherDimSizesMul = 1;
1255   for (auto otherDimIndex : reassocation[inDimIndex]) {
1256     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1257       continue;
1258     int64_t otherDimSize = outStaticShape[otherDimIndex];
1259     assert(!ShapedType::isDynamic(otherDimSize) &&
1260            "single dimension cannot be expanded into multiple dynamic "
1261            "dimensions");
1262     otherDimSizesMul *= otherDimSize;
1263   }
1264 
1265   // outDimSize = inDimSize / otherOutDimSizesMul
1266   int64_t inDimSize = inStaticShape[inDimIndex];
1267   Value inDimSizeDynamic =
1268       ShapedType::isDynamic(inDimSize)
1269           ? inDesc.size(b, loc, inDimIndex)
1270           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1271                                        b.getIndexAttr(inDimSize));
1272   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1273       loc, inDimSizeDynamic,
1274       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1275                                  b.getIndexAttr(otherDimSizesMul)));
1276   return outDimSizeDynamic;
1277 }
1278 
1279 static OpFoldResult getCollapsedOutputDimSize(
1280     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1281     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1282     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1283   if (!ShapedType::isDynamic(outDimSize))
1284     return b.getIndexAttr(outDimSize);
1285 
1286   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1287   Value outDimSizeDynamic = c1;
1288   for (auto inDimIndex : reassocation[outDimIndex]) {
1289     int64_t inDimSize = inStaticShape[inDimIndex];
1290     Value inDimSizeDynamic =
1291         ShapedType::isDynamic(inDimSize)
1292             ? inDesc.size(b, loc, inDimIndex)
1293             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1294                                          b.getIndexAttr(inDimSize));
1295     outDimSizeDynamic =
1296         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1297   }
1298   return outDimSizeDynamic;
1299 }
1300 
1301 static SmallVector<OpFoldResult, 4>
1302 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1303                         ArrayRef<ReassociationIndices> reassocation,
1304                         ArrayRef<int64_t> inStaticShape,
1305                         MemRefDescriptor &inDesc,
1306                         ArrayRef<int64_t> outStaticShape) {
1307   return llvm::to_vector<4>(llvm::map_range(
1308       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1309         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1310                                          outStaticShape[outDimIndex],
1311                                          inStaticShape, inDesc, reassocation);
1312       }));
1313 }
1314 
1315 static SmallVector<OpFoldResult, 4>
1316 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1317                        ArrayRef<ReassociationIndices> reassocation,
1318                        ArrayRef<int64_t> inStaticShape,
1319                        MemRefDescriptor &inDesc,
1320                        ArrayRef<int64_t> outStaticShape) {
1321   DenseMap<int64_t, int64_t> outDimToInDimMap =
1322       getExpandedDimToCollapsedDimMap(reassocation);
1323   return llvm::to_vector<4>(llvm::map_range(
1324       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1325         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1326                                         outStaticShape, inDesc, inStaticShape,
1327                                         reassocation, outDimToInDimMap);
1328       }));
1329 }
1330 
1331 static SmallVector<Value>
1332 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1333                       ArrayRef<ReassociationIndices> reassocation,
1334                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1335                       ArrayRef<int64_t> outStaticShape) {
1336   return outStaticShape.size() < inStaticShape.size()
1337              ? getAsValues(b, loc, llvmIndexType,
1338                            getCollapsedOutputShape(b, loc, llvmIndexType,
1339                                                    reassocation, inStaticShape,
1340                                                    inDesc, outStaticShape))
1341              : getAsValues(b, loc, llvmIndexType,
1342                            getExpandedOutputShape(b, loc, llvmIndexType,
1343                                                   reassocation, inStaticShape,
1344                                                   inDesc, outStaticShape));
1345 }
1346 
1347 // ReshapeOp creates a new view descriptor of the proper rank.
1348 // For now, the only conversion supported is for target MemRef with static sizes
1349 // and strides.
1350 template <typename ReshapeOp>
1351 class ReassociatingReshapeOpConversion
1352     : public ConvertOpToLLVMPattern<ReshapeOp> {
1353 public:
1354   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1355   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1356 
1357   LogicalResult
1358   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1359                   ConversionPatternRewriter &rewriter) const override {
1360     MemRefType dstType = reshapeOp.getResultType();
1361     MemRefType srcType = reshapeOp.getSrcType();
1362 
1363     // The condition on the layouts can be ignored when all shapes are static.
1364     if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
1365       if (!srcType.getLayout().isIdentity() ||
1366           !dstType.getLayout().isIdentity()) {
1367         return rewriter.notifyMatchFailure(
1368             reshapeOp, "only empty layout map is supported");
1369       }
1370     }
1371 
1372     int64_t offset;
1373     SmallVector<int64_t, 4> strides;
1374     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1375       return rewriter.notifyMatchFailure(
1376           reshapeOp, "failed to get stride and offset exprs");
1377     }
1378 
1379     MemRefDescriptor srcDesc(adaptor.src());
1380     Location loc = reshapeOp->getLoc();
1381     auto dstDesc = MemRefDescriptor::undef(
1382         rewriter, loc, this->typeConverter->convertType(dstType));
1383     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1384     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1385     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1386 
1387     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1388     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1389     Type llvmIndexType =
1390         this->typeConverter->convertType(rewriter.getIndexType());
1391     SmallVector<Value> dstShape = getDynamicOutputShape(
1392         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1393         srcStaticShape, srcDesc, dstStaticShape);
1394     for (auto &en : llvm::enumerate(dstShape))
1395       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1396 
1397     auto isStaticStride = [](int64_t stride) {
1398       return !ShapedType::isDynamicStrideOrOffset(stride);
1399     };
1400     if (llvm::all_of(strides, isStaticStride)) {
1401       for (auto &en : llvm::enumerate(strides))
1402         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1403     } else {
1404       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1405                                                    rewriter.getIndexAttr(1));
1406       Value stride = c1;
1407       for (auto dimIndex :
1408            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1409         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1410         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1411       }
1412     }
1413     rewriter.replaceOp(reshapeOp, {dstDesc});
1414     return success();
1415   }
1416 };
1417 
1418 /// Conversion pattern that transforms a subview op into:
1419 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1420 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1421 ///      and stride.
1422 /// The subview op is replaced by the descriptor.
1423 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1424   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1425 
1426   LogicalResult
1427   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1428                   ConversionPatternRewriter &rewriter) const override {
1429     auto loc = subViewOp.getLoc();
1430 
1431     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1432     auto sourceElementTy =
1433         typeConverter->convertType(sourceMemRefType.getElementType());
1434 
1435     auto viewMemRefType = subViewOp.getType();
1436     auto inferredType = memref::SubViewOp::inferResultType(
1437                             subViewOp.getSourceType(),
1438                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1439                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1440                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1441                             .cast<MemRefType>();
1442     auto targetElementTy =
1443         typeConverter->convertType(viewMemRefType.getElementType());
1444     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1445     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1446         !LLVM::isCompatibleType(sourceElementTy) ||
1447         !LLVM::isCompatibleType(targetElementTy) ||
1448         !LLVM::isCompatibleType(targetDescTy))
1449       return failure();
1450 
1451     // Extract the offset and strides from the type.
1452     int64_t offset;
1453     SmallVector<int64_t, 4> strides;
1454     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1455     if (failed(successStrides))
1456       return failure();
1457 
1458     // Create the descriptor.
1459     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1460       return failure();
1461     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1462     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1463 
1464     // Copy the buffer pointer from the old descriptor to the new one.
1465     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1466     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1467         loc,
1468         LLVM::LLVMPointerType::get(targetElementTy,
1469                                    viewMemRefType.getMemorySpaceAsInt()),
1470         extracted);
1471     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1472 
1473     // Copy the aligned pointer from the old descriptor to the new one.
1474     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1475     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1476         loc,
1477         LLVM::LLVMPointerType::get(targetElementTy,
1478                                    viewMemRefType.getMemorySpaceAsInt()),
1479         extracted);
1480     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1481 
1482     size_t inferredShapeRank = inferredType.getRank();
1483     size_t resultShapeRank = viewMemRefType.getRank();
1484 
1485     // Extract strides needed to compute offset.
1486     SmallVector<Value, 4> strideValues;
1487     strideValues.reserve(inferredShapeRank);
1488     for (unsigned i = 0; i < inferredShapeRank; ++i)
1489       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1490 
1491     // Offset.
1492     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1493     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1494       targetMemRef.setConstantOffset(rewriter, loc, offset);
1495     } else {
1496       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1497       // `inferredShapeRank` may be larger than the number of offset operands
1498       // because of trailing semantics. In this case, the offset is guaranteed
1499       // to be interpreted as 0 and we can just skip the extra dimensions.
1500       for (unsigned i = 0, e = std::min(inferredShapeRank,
1501                                         subViewOp.getMixedOffsets().size());
1502            i < e; ++i) {
1503         Value offset =
1504             // TODO: need OpFoldResult ODS adaptor to clean this up.
1505             subViewOp.isDynamicOffset(i)
1506                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1507                 : rewriter.create<LLVM::ConstantOp>(
1508                       loc, llvmIndexType,
1509                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1510         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1511         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1512       }
1513       targetMemRef.setOffset(rewriter, loc, baseOffset);
1514     }
1515 
1516     // Update sizes and strides.
1517     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1518     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1519     assert(mixedSizes.size() == mixedStrides.size() &&
1520            "expected sizes and strides of equal length");
1521     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1522     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1523          i >= 0 && j >= 0; --i) {
1524       if (unusedDims.test(i))
1525         continue;
1526 
1527       // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
1528       // In this case, the size is guaranteed to be interpreted as Dim and the
1529       // stride as 1.
1530       Value size, stride;
1531       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1532         // If the static size is available, use it directly. This is similar to
1533         // the folding of dim(constant-op) but removes the need for dim to be
1534         // aware of LLVM constants and for this pass to be aware of std
1535         // constants.
1536         int64_t staticSize =
1537             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1538         if (staticSize != ShapedType::kDynamicSize) {
1539           size = rewriter.create<LLVM::ConstantOp>(
1540               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1541         } else {
1542           Value pos = rewriter.create<LLVM::ConstantOp>(
1543               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1544           Value dim =
1545               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1546           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1547               loc, llvmIndexType, dim);
1548           size = cast.getResult(0);
1549         }
1550         stride = rewriter.create<LLVM::ConstantOp>(
1551             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1552       } else {
1553         // TODO: need OpFoldResult ODS adaptor to clean this up.
1554         size =
1555             subViewOp.isDynamicSize(i)
1556                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1557                 : rewriter.create<LLVM::ConstantOp>(
1558                       loc, llvmIndexType,
1559                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1560         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1561           stride = rewriter.create<LLVM::ConstantOp>(
1562               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1563         } else {
1564           stride =
1565               subViewOp.isDynamicStride(i)
1566                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1567                   : rewriter.create<LLVM::ConstantOp>(
1568                         loc, llvmIndexType,
1569                         rewriter.getI64IntegerAttr(
1570                             subViewOp.getStaticStride(i)));
1571           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1572         }
1573       }
1574       targetMemRef.setSize(rewriter, loc, j, size);
1575       targetMemRef.setStride(rewriter, loc, j, stride);
1576       j--;
1577     }
1578 
1579     rewriter.replaceOp(subViewOp, {targetMemRef});
1580     return success();
1581   }
1582 };
1583 
1584 /// Conversion pattern that transforms a transpose op into:
1585 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
1586 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
1587 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1588 ///      and stride. Size and stride are permutations of the original values.
1589 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1590 /// The transpose op is replaced by the alloca'ed pointer.
1591 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1592 public:
1593   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1594 
1595   LogicalResult
1596   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1597                   ConversionPatternRewriter &rewriter) const override {
1598     auto loc = transposeOp.getLoc();
1599     MemRefDescriptor viewMemRef(adaptor.in());
1600 
1601     // No permutation, early exit.
1602     if (transposeOp.permutation().isIdentity())
1603       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1604 
1605     auto targetMemRef = MemRefDescriptor::undef(
1606         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1607 
1608     // Copy the base and aligned pointers from the old descriptor to the new
1609     // one.
1610     targetMemRef.setAllocatedPtr(rewriter, loc,
1611                                  viewMemRef.allocatedPtr(rewriter, loc));
1612     targetMemRef.setAlignedPtr(rewriter, loc,
1613                                viewMemRef.alignedPtr(rewriter, loc));
1614 
1615     // Copy the offset pointer from the old descriptor to the new one.
1616     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1617 
1618     // Iterate over the dimensions and apply size/stride permutation.
1619     for (const auto &en :
1620          llvm::enumerate(transposeOp.permutation().getResults())) {
1621       int sourcePos = en.index();
1622       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1623       targetMemRef.setSize(rewriter, loc, targetPos,
1624                            viewMemRef.size(rewriter, loc, sourcePos));
1625       targetMemRef.setStride(rewriter, loc, targetPos,
1626                              viewMemRef.stride(rewriter, loc, sourcePos));
1627     }
1628 
1629     rewriter.replaceOp(transposeOp, {targetMemRef});
1630     return success();
1631   }
1632 };
1633 
1634 /// Conversion pattern that transforms an op into:
1635 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1636 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1637 ///      and stride.
1638 /// The view op is replaced by the descriptor.
1639 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1640   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1641 
1642   // Build and return the value for the idx^th shape dimension, either by
1643   // returning the constant shape dimension or counting the proper dynamic size.
1644   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1645                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1646                 unsigned idx) const {
1647     assert(idx < shape.size());
1648     if (!ShapedType::isDynamic(shape[idx]))
1649       return createIndexConstant(rewriter, loc, shape[idx]);
1650     // Count the number of dynamic dims in range [0, idx]
1651     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1652       return ShapedType::isDynamic(v);
1653     });
1654     return dynamicSizes[nDynamic];
1655   }
1656 
1657   // Build and return the idx^th stride, either by returning the constant stride
1658   // or by computing the dynamic stride from the current `runningStride` and
1659   // `nextSize`. The caller should keep a running stride and update it with the
1660   // result returned by this function.
1661   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1662                   ArrayRef<int64_t> strides, Value nextSize,
1663                   Value runningStride, unsigned idx) const {
1664     assert(idx < strides.size());
1665     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1666       return createIndexConstant(rewriter, loc, strides[idx]);
1667     if (nextSize)
1668       return runningStride
1669                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1670                  : nextSize;
1671     assert(!runningStride);
1672     return createIndexConstant(rewriter, loc, 1);
1673   }
1674 
1675   LogicalResult
1676   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1677                   ConversionPatternRewriter &rewriter) const override {
1678     auto loc = viewOp.getLoc();
1679 
1680     auto viewMemRefType = viewOp.getType();
1681     auto targetElementTy =
1682         typeConverter->convertType(viewMemRefType.getElementType());
1683     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1684     if (!targetDescTy || !targetElementTy ||
1685         !LLVM::isCompatibleType(targetElementTy) ||
1686         !LLVM::isCompatibleType(targetDescTy))
1687       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1688              failure();
1689 
1690     int64_t offset;
1691     SmallVector<int64_t, 4> strides;
1692     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1693     if (failed(successStrides))
1694       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1695     assert(offset == 0 && "expected offset to be 0");
1696 
1697     // Create the descriptor.
1698     MemRefDescriptor sourceMemRef(adaptor.source());
1699     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1700 
1701     // Field 1: Copy the allocated pointer, used for malloc/free.
1702     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1703     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1704     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1705         loc,
1706         LLVM::LLVMPointerType::get(targetElementTy,
1707                                    srcMemRefType.getMemorySpaceAsInt()),
1708         allocatedPtr);
1709     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1710 
1711     // Field 2: Copy the actual aligned pointer to payload.
1712     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1713     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1714                                               alignedPtr, adaptor.byte_shift());
1715     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1716         loc,
1717         LLVM::LLVMPointerType::get(targetElementTy,
1718                                    srcMemRefType.getMemorySpaceAsInt()),
1719         alignedPtr);
1720     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1721 
1722     // Field 3: The offset in the resulting type must be 0. This is because of
1723     // the type change: an offset on srcType* may not be expressible as an
1724     // offset on dstType*.
1725     targetMemRef.setOffset(rewriter, loc,
1726                            createIndexConstant(rewriter, loc, offset));
1727 
1728     // Early exit for 0-D corner case.
1729     if (viewMemRefType.getRank() == 0)
1730       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1731 
1732     // Fields 4 and 5: Update sizes and strides.
1733     if (strides.back() != 1)
1734       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1735              failure();
1736     Value stride = nullptr, nextSize = nullptr;
1737     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1738       // Update size.
1739       Value size =
1740           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1741       targetMemRef.setSize(rewriter, loc, i, size);
1742       // Update stride.
1743       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1744       targetMemRef.setStride(rewriter, loc, i, stride);
1745       nextSize = size;
1746     }
1747 
1748     rewriter.replaceOp(viewOp, {targetMemRef});
1749     return success();
1750   }
1751 };
1752 
1753 //===----------------------------------------------------------------------===//
1754 // AtomicRMWOpLowering
1755 //===----------------------------------------------------------------------===//
1756 
1757 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1758 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1759 static Optional<LLVM::AtomicBinOp>
1760 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1761   switch (atomicOp.kind()) {
1762   case arith::AtomicRMWKind::addf:
1763     return LLVM::AtomicBinOp::fadd;
1764   case arith::AtomicRMWKind::addi:
1765     return LLVM::AtomicBinOp::add;
1766   case arith::AtomicRMWKind::assign:
1767     return LLVM::AtomicBinOp::xchg;
1768   case arith::AtomicRMWKind::maxs:
1769     return LLVM::AtomicBinOp::max;
1770   case arith::AtomicRMWKind::maxu:
1771     return LLVM::AtomicBinOp::umax;
1772   case arith::AtomicRMWKind::mins:
1773     return LLVM::AtomicBinOp::min;
1774   case arith::AtomicRMWKind::minu:
1775     return LLVM::AtomicBinOp::umin;
1776   case arith::AtomicRMWKind::ori:
1777     return LLVM::AtomicBinOp::_or;
1778   case arith::AtomicRMWKind::andi:
1779     return LLVM::AtomicBinOp::_and;
1780   default:
1781     return llvm::None;
1782   }
1783   llvm_unreachable("Invalid AtomicRMWKind");
1784 }
1785 
1786 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1787   using Base::Base;
1788 
1789   LogicalResult
1790   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1791                   ConversionPatternRewriter &rewriter) const override {
1792     if (failed(match(atomicOp)))
1793       return failure();
1794     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1795     if (!maybeKind)
1796       return failure();
1797     auto resultType = adaptor.value().getType();
1798     auto memRefType = atomicOp.getMemRefType();
1799     auto dataPtr =
1800         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1801                              adaptor.indices(), rewriter);
1802     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1803         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1804         LLVM::AtomicOrdering::acq_rel);
1805     return success();
1806   }
1807 };
1808 
1809 } // namespace
1810 
1811 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1812                                                   RewritePatternSet &patterns) {
1813   // clang-format off
1814   patterns.add<
1815       AllocaOpLowering,
1816       AllocaScopeOpLowering,
1817       AtomicRMWOpLowering,
1818       AssumeAlignmentOpLowering,
1819       DimOpLowering,
1820       GenericAtomicRMWOpLowering,
1821       GlobalMemrefOpLowering,
1822       GetGlobalMemrefOpLowering,
1823       LoadOpLowering,
1824       MemRefCastOpLowering,
1825       MemRefCopyOpLowering,
1826       MemRefReinterpretCastOpLowering,
1827       MemRefReshapeOpLowering,
1828       PrefetchOpLowering,
1829       RankOpLowering,
1830       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1831       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1832       StoreOpLowering,
1833       SubViewOpLowering,
1834       TransposeOpLowering,
1835       ViewOpLowering>(converter);
1836   // clang-format on
1837   auto allocLowering = converter.getOptions().allocLowering;
1838   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1839     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1840   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1841     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1842 }
1843 
1844 namespace {
1845 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1846   MemRefToLLVMPass() = default;
1847 
1848   void runOnOperation() override {
1849     Operation *op = getOperation();
1850     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1851     LowerToLLVMOptions options(&getContext(),
1852                                dataLayoutAnalysis.getAtOrAbove(op));
1853     options.allocLowering =
1854         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1855                          : LowerToLLVMOptions::AllocLowering::Malloc);
1856     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1857       options.overrideIndexBitwidth(indexBitwidth);
1858 
1859     LLVMTypeConverter typeConverter(&getContext(), options,
1860                                     &dataLayoutAnalysis);
1861     RewritePatternSet patterns(&getContext());
1862     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1863     LLVMConversionTarget target(getContext());
1864     target.addLegalOp<func::FuncOp>();
1865     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1866       signalPassFailure();
1867   }
1868 };
1869 } // namespace
1870 
1871 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1872   return std::make_unique<MemRefToLLVMPass>();
1873 }
1874