1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
30   AllocOpLowering(LLVMTypeConverter &converter)
31       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
32                                 converter) {}
33 
34   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
35                                           Location loc, Value sizeBytes,
36                                           Operation *op) const override {
37     // Heap allocations.
38     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
39     MemRefType memRefType = allocOp.getType();
40 
41     Value alignment;
42     if (auto alignmentAttr = allocOp.alignment()) {
43       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
44     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
45       // In the case where no alignment is specified, we may want to override
46       // `malloc's` behavior. `malloc` typically aligns at the size of the
47       // biggest scalar on a target HW. For non-scalars, use the natural
48       // alignment of the LLVM type given by the LLVM DataLayout.
49       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
50     }
51 
52     if (alignment) {
53       // Adjust the allocation size to consider alignment.
54       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
55     }
56 
57     // Allocate the underlying buffer and store a pointer to it in the MemRef
58     // descriptor.
59     Type elementPtrType = this->getElementPtrType(memRefType);
60     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
61         allocOp->getParentOfType<ModuleOp>(), getIndexType());
62     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
63                                   getVoidPtrType());
64     Value allocatedPtr =
65         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
66 
67     Value alignedPtr = allocatedPtr;
68     if (alignment) {
69       // Compute the aligned type pointer.
70       Value allocatedInt =
71           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
72       Value alignmentInt =
73           createAligned(rewriter, loc, allocatedInt, alignment);
74       alignedPtr =
75           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
76     }
77 
78     return std::make_tuple(allocatedPtr, alignedPtr);
79   }
80 };
81 
82 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
83   AlignedAllocOpLowering(LLVMTypeConverter &converter)
84       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
85                                 converter) {}
86 
87   /// Returns the memref's element size in bytes using the data layout active at
88   /// `op`.
89   // TODO: there are other places where this is used. Expose publicly?
90   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
91     const DataLayout *layout = &defaultLayout;
92     if (const DataLayoutAnalysis *analysis =
93             getTypeConverter()->getDataLayoutAnalysis()) {
94       layout = &analysis->getAbove(op);
95     }
96     Type elementType = memRefType.getElementType();
97     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
98       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
99                                                          *layout);
100     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
101       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
102           memRefElementType, *layout);
103     return layout->getTypeSize(elementType);
104   }
105 
106   /// Returns true if the memref size in bytes is known to be a multiple of
107   /// factor assuming the data layout active at `op`.
108   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
109                               Operation *op) const {
110     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
111     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
112       if (ShapedType::isDynamic(type.getDimSize(i)))
113         continue;
114       sizeDivisor = sizeDivisor * type.getDimSize(i);
115     }
116     return sizeDivisor % factor == 0;
117   }
118 
119   /// Returns the alignment to be used for the allocation call itself.
120   /// aligned_alloc requires the allocation size to be a power of two, and the
121   /// allocation size to be a multiple of alignment,
122   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
123     if (Optional<uint64_t> alignment = allocOp.alignment())
124       return *alignment;
125 
126     // Whenever we don't have alignment set, we will use an alignment
127     // consistent with the element type; since the allocation size has to be a
128     // power of two, we will bump to the next power of two if it already isn't.
129     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
130     return std::max(kMinAlignedAllocAlignment,
131                     llvm::PowerOf2Ceil(eltSizeBytes));
132   }
133 
134   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
135                                           Location loc, Value sizeBytes,
136                                           Operation *op) const override {
137     // Heap allocations.
138     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
139     MemRefType memRefType = allocOp.getType();
140     int64_t alignment = getAllocationAlignment(allocOp);
141     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
142 
143     // aligned_alloc requires size to be a multiple of alignment; we will pad
144     // the size to the next multiple if necessary.
145     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
146       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
147 
148     Type elementPtrType = this->getElementPtrType(memRefType);
149     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
150         allocOp->getParentOfType<ModuleOp>(), getIndexType());
151     auto results =
152         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
153                        getVoidPtrType());
154     Value allocatedPtr =
155         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
156 
157     return std::make_tuple(allocatedPtr, allocatedPtr);
158   }
159 
160   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
161   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
162 
163   /// Default layout to use in absence of the corresponding analysis.
164   DataLayout defaultLayout;
165 };
166 
167 // Out of line definition, required till C++17.
168 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
169 
170 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
171   AllocaOpLowering(LLVMTypeConverter &converter)
172       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
173                                 converter) {}
174 
175   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
176   /// is set to null for stack allocations. `accessAlignment` is set if
177   /// alignment is needed post allocation (for eg. in conjunction with malloc).
178   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
179                                           Location loc, Value sizeBytes,
180                                           Operation *op) const override {
181 
182     // With alloca, one gets a pointer to the element type right away.
183     // For stack allocations.
184     auto allocaOp = cast<memref::AllocaOp>(op);
185     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
186 
187     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
188         loc, elementPtrType, sizeBytes,
189         allocaOp.alignment() ? *allocaOp.alignment() : 0);
190 
191     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
192   }
193 };
194 
195 struct AllocaScopeOpLowering
196     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
197   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
198 
199   LogicalResult
200   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
201                   ConversionPatternRewriter &rewriter) const override {
202     OpBuilder::InsertionGuard guard(rewriter);
203     Location loc = allocaScopeOp.getLoc();
204 
205     // Split the current block before the AllocaScopeOp to create the inlining
206     // point.
207     auto *currentBlock = rewriter.getInsertionBlock();
208     auto *remainingOpsBlock =
209         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
210     Block *continueBlock;
211     if (allocaScopeOp.getNumResults() == 0) {
212       continueBlock = remainingOpsBlock;
213     } else {
214       continueBlock = rewriter.createBlock(
215           remainingOpsBlock, allocaScopeOp.getResultTypes(),
216           SmallVector<Location>(allocaScopeOp->getNumResults(),
217                                 allocaScopeOp.getLoc()));
218       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
219     }
220 
221     // Inline body region.
222     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
223     Block *afterBody = &allocaScopeOp.bodyRegion().back();
224     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
225 
226     // Save stack and then branch into the body of the region.
227     rewriter.setInsertionPointToEnd(currentBlock);
228     auto stackSaveOp =
229         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
230     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
231 
232     // Replace the alloca_scope return with a branch that jumps out of the body.
233     // Stack restore before leaving the body region.
234     rewriter.setInsertionPointToEnd(afterBody);
235     auto returnOp =
236         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
237     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
238         returnOp, returnOp.results(), continueBlock);
239 
240     // Insert stack restore before jumping out the body of the region.
241     rewriter.setInsertionPoint(branchOp);
242     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
243 
244     // Replace the op with values return from the body region.
245     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
246 
247     return success();
248   }
249 };
250 
251 struct AssumeAlignmentOpLowering
252     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
253   using ConvertOpToLLVMPattern<
254       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
255 
256   LogicalResult
257   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
258                   ConversionPatternRewriter &rewriter) const override {
259     Value memref = adaptor.memref();
260     unsigned alignment = op.alignment();
261     auto loc = op.getLoc();
262 
263     MemRefDescriptor memRefDescriptor(memref);
264     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
265 
266     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
267     // the asserted memref.alignedPtr isn't used anywhere else, as the real
268     // users like load/store/views always re-extract memref.alignedPtr as they
269     // get lowered.
270     //
271     // This relies on LLVM's CSE optimization (potentially after SROA), since
272     // after CSE all memref.alignedPtr instances get de-duplicated into the same
273     // pointer SSA value.
274     auto intPtrType =
275         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
276     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
277     Value mask =
278         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
279     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
280     rewriter.create<LLVM::AssumeOp>(
281         loc, rewriter.create<LLVM::ICmpOp>(
282                  loc, LLVM::ICmpPredicate::eq,
283                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
284 
285     rewriter.eraseOp(op);
286     return success();
287   }
288 };
289 
290 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
291 // The memref descriptor being an SSA value, there is no need to clean it up
292 // in any way.
293 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
294   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
295 
296   explicit DeallocOpLowering(LLVMTypeConverter &converter)
297       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
298 
299   LogicalResult
300   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
301                   ConversionPatternRewriter &rewriter) const override {
302     // Insert the `free` declaration if it is not already present.
303     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
304     MemRefDescriptor memref(adaptor.memref());
305     Value casted = rewriter.create<LLVM::BitcastOp>(
306         op.getLoc(), getVoidPtrType(),
307         memref.allocatedPtr(rewriter, op.getLoc()));
308     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
309         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
310     return success();
311   }
312 };
313 
314 // A `dim` is converted to a constant for static sizes and to an access to the
315 // size stored in the memref descriptor for dynamic sizes.
316 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
317   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
318 
319   LogicalResult
320   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
321                   ConversionPatternRewriter &rewriter) const override {
322     Type operandType = dimOp.source().getType();
323     if (operandType.isa<UnrankedMemRefType>()) {
324       rewriter.replaceOp(
325           dimOp, {extractSizeOfUnrankedMemRef(
326                      operandType, dimOp, adaptor.getOperands(), rewriter)});
327 
328       return success();
329     }
330     if (operandType.isa<MemRefType>()) {
331       rewriter.replaceOp(
332           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
333                                             adaptor.getOperands(), rewriter)});
334       return success();
335     }
336     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
337   }
338 
339 private:
340   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
341                                     OpAdaptor adaptor,
342                                     ConversionPatternRewriter &rewriter) const {
343     Location loc = dimOp.getLoc();
344 
345     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
346     auto scalarMemRefType =
347         MemRefType::get({}, unrankedMemRefType.getElementType());
348     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
349 
350     // Extract pointer to the underlying ranked descriptor and bitcast it to a
351     // memref<element_type> descriptor pointer to minimize the number of GEP
352     // operations.
353     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
354     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
355     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
356         loc,
357         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
358                                    addressSpace),
359         underlyingRankedDesc);
360 
361     // Get pointer to offset field of memref<element_type> descriptor.
362     Type indexPtrTy = LLVM::LLVMPointerType::get(
363         getTypeConverter()->getIndexType(), addressSpace);
364     Value two = rewriter.create<LLVM::ConstantOp>(
365         loc, typeConverter->convertType(rewriter.getI32Type()),
366         rewriter.getI32IntegerAttr(2));
367     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
368         loc, indexPtrTy, scalarMemRefDescPtr,
369         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
370 
371     // The size value that we have to extract can be obtained using GEPop with
372     // `dimOp.index() + 1` index argument.
373     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
374         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
375     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
376                                                  ValueRange({idxPlusOne}));
377     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
378   }
379 
380   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
381     if (Optional<int64_t> idx = dimOp.getConstantIndex())
382       return idx;
383 
384     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
385       return constantOp.getValue()
386           .cast<IntegerAttr>()
387           .getValue()
388           .getSExtValue();
389 
390     return llvm::None;
391   }
392 
393   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
394                                   OpAdaptor adaptor,
395                                   ConversionPatternRewriter &rewriter) const {
396     Location loc = dimOp.getLoc();
397 
398     // Take advantage if index is constant.
399     MemRefType memRefType = operandType.cast<MemRefType>();
400     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
401       int64_t i = index.getValue();
402       if (memRefType.isDynamicDim(i)) {
403         // extract dynamic size from the memref descriptor.
404         MemRefDescriptor descriptor(adaptor.source());
405         return descriptor.size(rewriter, loc, i);
406       }
407       // Use constant for static size.
408       int64_t dimSize = memRefType.getDimSize(i);
409       return createIndexConstant(rewriter, loc, dimSize);
410     }
411     Value index = adaptor.index();
412     int64_t rank = memRefType.getRank();
413     MemRefDescriptor memrefDescriptor(adaptor.source());
414     return memrefDescriptor.size(rewriter, loc, index, rank);
415   }
416 };
417 
418 /// Common base for load and store operations on MemRefs. Restricts the match
419 /// to supported MemRef types. Provides functionality to emit code accessing a
420 /// specific element of the underlying data buffer.
421 template <typename Derived>
422 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
423   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
424   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
425   using Base = LoadStoreOpLowering<Derived>;
426 
427   LogicalResult match(Derived op) const override {
428     MemRefType type = op.getMemRefType();
429     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
430   }
431 };
432 
433 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
434 /// retried until it succeeds in atomically storing a new value into memory.
435 ///
436 ///      +---------------------------------+
437 ///      |   <code before the AtomicRMWOp> |
438 ///      |   <compute initial %loaded>     |
439 ///      |   cf.br loop(%loaded)              |
440 ///      +---------------------------------+
441 ///             |
442 ///  -------|   |
443 ///  |      v   v
444 ///  |   +--------------------------------+
445 ///  |   | loop(%loaded):                 |
446 ///  |   |   <body contents>              |
447 ///  |   |   %pair = cmpxchg              |
448 ///  |   |   %ok = %pair[0]               |
449 ///  |   |   %new = %pair[1]              |
450 ///  |   |   cf.cond_br %ok, end, loop(%new) |
451 ///  |   +--------------------------------+
452 ///  |          |        |
453 ///  |-----------        |
454 ///                      v
455 ///      +--------------------------------+
456 ///      | end:                           |
457 ///      |   <code after the AtomicRMWOp> |
458 ///      +--------------------------------+
459 ///
460 struct GenericAtomicRMWOpLowering
461     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
462   using Base::Base;
463 
464   LogicalResult
465   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
466                   ConversionPatternRewriter &rewriter) const override {
467     auto loc = atomicOp.getLoc();
468     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
469 
470     // Split the block into initial, loop, and ending parts.
471     auto *initBlock = rewriter.getInsertionBlock();
472     auto *loopBlock = rewriter.createBlock(
473         initBlock->getParent(), std::next(Region::iterator(initBlock)),
474         valueType, loc);
475     auto *endBlock = rewriter.createBlock(
476         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
477 
478     // Operations range to be moved to `endBlock`.
479     auto opsToMoveStart = atomicOp->getIterator();
480     auto opsToMoveEnd = initBlock->back().getIterator();
481 
482     // Compute the loaded value and branch to the loop block.
483     rewriter.setInsertionPointToEnd(initBlock);
484     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
485     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
486                                         adaptor.indices(), rewriter);
487     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
488     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
489 
490     // Prepare the body of the loop block.
491     rewriter.setInsertionPointToStart(loopBlock);
492 
493     // Clone the GenericAtomicRMWOp region and extract the result.
494     auto loopArgument = loopBlock->getArgument(0);
495     BlockAndValueMapping mapping;
496     mapping.map(atomicOp.getCurrentValue(), loopArgument);
497     Block &entryBlock = atomicOp.body().front();
498     for (auto &nestedOp : entryBlock.without_terminator()) {
499       Operation *clone = rewriter.clone(nestedOp, mapping);
500       mapping.map(nestedOp.getResults(), clone->getResults());
501     }
502     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
503 
504     // Prepare the epilog of the loop block.
505     // Append the cmpxchg op to the end of the loop block.
506     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
507     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
508     auto boolType = IntegerType::get(rewriter.getContext(), 1);
509     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
510                                                      {valueType, boolType});
511     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
512         loc, pairType, dataPtr, loopArgument, result, successOrdering,
513         failureOrdering);
514     // Extract the %new_loaded and %ok values from the pair.
515     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
516         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
517     Value ok = rewriter.create<LLVM::ExtractValueOp>(
518         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
519 
520     // Conditionally branch to the end or back to the loop depending on %ok.
521     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
522                                     loopBlock, newLoaded);
523 
524     rewriter.setInsertionPointToEnd(endBlock);
525     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
526                  std::next(opsToMoveEnd), rewriter);
527 
528     // The 'result' of the atomic_rmw op is the newly loaded value.
529     rewriter.replaceOp(atomicOp, {newLoaded});
530 
531     return success();
532   }
533 
534 private:
535   // Clones a segment of ops [start, end) and erases the original.
536   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
537                     Block::iterator start, Block::iterator end,
538                     ConversionPatternRewriter &rewriter) const {
539     BlockAndValueMapping mapping;
540     mapping.map(oldResult, newResult);
541     SmallVector<Operation *, 2> opsToErase;
542     for (auto it = start; it != end; ++it) {
543       rewriter.clone(*it, mapping);
544       opsToErase.push_back(&*it);
545     }
546     for (auto *it : opsToErase)
547       rewriter.eraseOp(it);
548   }
549 };
550 
551 /// Returns the LLVM type of the global variable given the memref type `type`.
552 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
553                                           LLVMTypeConverter &typeConverter) {
554   // LLVM type for a global memref will be a multi-dimension array. For
555   // declarations or uninitialized global memrefs, we can potentially flatten
556   // this to a 1D array. However, for memref.global's with an initial value,
557   // we do not intend to flatten the ElementsAttribute when going from std ->
558   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
559   Type elementType = typeConverter.convertType(type.getElementType());
560   Type arrayTy = elementType;
561   // Shape has the outermost dim at index 0, so need to walk it backwards
562   for (int64_t dim : llvm::reverse(type.getShape()))
563     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
564   return arrayTy;
565 }
566 
567 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
568 struct GlobalMemrefOpLowering
569     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
570   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
571 
572   LogicalResult
573   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
574                   ConversionPatternRewriter &rewriter) const override {
575     MemRefType type = global.type();
576     if (!isConvertibleAndHasIdentityMaps(type))
577       return failure();
578 
579     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
580 
581     LLVM::Linkage linkage =
582         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
583 
584     Attribute initialValue = nullptr;
585     if (!global.isExternal() && !global.isUninitialized()) {
586       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
587       initialValue = elementsAttr;
588 
589       // For scalar memrefs, the global variable created is of the element type,
590       // so unpack the elements attribute to extract the value.
591       if (type.getRank() == 0)
592         initialValue = elementsAttr.getSplatValue<Attribute>();
593     }
594 
595     uint64_t alignment = global.alignment().getValueOr(0);
596 
597     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
598         global, arrayTy, global.constant(), linkage, global.sym_name(),
599         initialValue, alignment, type.getMemorySpaceAsInt());
600     if (!global.isExternal() && global.isUninitialized()) {
601       Block *blk = new Block();
602       newGlobal.getInitializerRegion().push_back(blk);
603       rewriter.setInsertionPointToStart(blk);
604       Value undef[] = {
605           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
606       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
607     }
608     return success();
609   }
610 };
611 
612 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
613 /// the first element stashed into the descriptor. This reuses
614 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
615 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
616   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
617       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
618                                 converter) {}
619 
620   /// Buffer "allocation" for memref.get_global op is getting the address of
621   /// the global variable referenced.
622   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
623                                           Location loc, Value sizeBytes,
624                                           Operation *op) const override {
625     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
626     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
627     unsigned memSpace = type.getMemorySpaceAsInt();
628 
629     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
630     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
631         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
632 
633     // Get the address of the first element in the array by creating a GEP with
634     // the address of the GV as the base, and (rank + 1) number of 0 indices.
635     Type elementType = typeConverter->convertType(type.getElementType());
636     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
637 
638     SmallVector<Value> operands;
639     operands.insert(operands.end(), type.getRank() + 1,
640                     createIndexConstant(rewriter, loc, 0));
641     auto gep =
642         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
643 
644     // We do not expect the memref obtained using `memref.get_global` to be
645     // ever deallocated. Set the allocated pointer to be known bad value to
646     // help debug if that ever happens.
647     auto intPtrType = getIntPtrType(memSpace);
648     Value deadBeefConst =
649         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
650     auto deadBeefPtr =
651         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
652 
653     // Both allocated and aligned pointers are same. We could potentially stash
654     // a nullptr for the allocated pointer since we do not expect any dealloc.
655     return std::make_tuple(deadBeefPtr, gep);
656   }
657 };
658 
659 // Load operation is lowered to obtaining a pointer to the indexed element
660 // and loading it.
661 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
662   using Base::Base;
663 
664   LogicalResult
665   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
666                   ConversionPatternRewriter &rewriter) const override {
667     auto type = loadOp.getMemRefType();
668 
669     Value dataPtr = getStridedElementPtr(
670         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
671     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
672     return success();
673   }
674 };
675 
676 // Store operation is lowered to obtaining a pointer to the indexed element,
677 // and storing the given value to it.
678 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
679   using Base::Base;
680 
681   LogicalResult
682   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
683                   ConversionPatternRewriter &rewriter) const override {
684     auto type = op.getMemRefType();
685 
686     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
687                                          adaptor.indices(), rewriter);
688     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
689     return success();
690   }
691 };
692 
693 // The prefetch operation is lowered in a way similar to the load operation
694 // except that the llvm.prefetch operation is used for replacement.
695 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
696   using Base::Base;
697 
698   LogicalResult
699   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
700                   ConversionPatternRewriter &rewriter) const override {
701     auto type = prefetchOp.getMemRefType();
702     auto loc = prefetchOp.getLoc();
703 
704     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
705                                          adaptor.indices(), rewriter);
706 
707     // Replace with llvm.prefetch.
708     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
709     auto isWrite = rewriter.create<LLVM::ConstantOp>(
710         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
711     auto localityHint = rewriter.create<LLVM::ConstantOp>(
712         loc, llvmI32Type,
713         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
714     auto isData = rewriter.create<LLVM::ConstantOp>(
715         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
716 
717     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
718                                                 localityHint, isData);
719     return success();
720   }
721 };
722 
723 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
724   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
725 
726   LogicalResult
727   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
728                   ConversionPatternRewriter &rewriter) const override {
729     Location loc = op.getLoc();
730     Type operandType = op.memref().getType();
731     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
732       UnrankedMemRefDescriptor desc(adaptor.memref());
733       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
734       return success();
735     }
736     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
737       rewriter.replaceOp(
738           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
739       return success();
740     }
741     return failure();
742   }
743 };
744 
745 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
746   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
747 
748   LogicalResult match(memref::CastOp memRefCastOp) const override {
749     Type srcType = memRefCastOp.getOperand().getType();
750     Type dstType = memRefCastOp.getType();
751 
752     // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
753     // used for type erasure. For now they must preserve underlying element type
754     // and require source and result type to have the same rank. Therefore,
755     // perform a sanity check that the underlying structs are the same. Once op
756     // semantics are relaxed we can revisit.
757     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
758       return success(typeConverter->convertType(srcType) ==
759                      typeConverter->convertType(dstType));
760 
761     // At least one of the operands is unranked type
762     assert(srcType.isa<UnrankedMemRefType>() ||
763            dstType.isa<UnrankedMemRefType>());
764 
765     // Unranked to unranked cast is disallowed
766     return !(srcType.isa<UnrankedMemRefType>() &&
767              dstType.isa<UnrankedMemRefType>())
768                ? success()
769                : failure();
770   }
771 
772   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
773                ConversionPatternRewriter &rewriter) const override {
774     auto srcType = memRefCastOp.getOperand().getType();
775     auto dstType = memRefCastOp.getType();
776     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
777     auto loc = memRefCastOp.getLoc();
778 
779     // For ranked/ranked case, just keep the original descriptor.
780     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
781       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
782 
783     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
784       // Casting ranked to unranked memref type
785       // Set the rank in the destination from the memref type
786       // Allocate space on the stack and copy the src memref descriptor
787       // Set the ptr in the destination to the stack space
788       auto srcMemRefType = srcType.cast<MemRefType>();
789       int64_t rank = srcMemRefType.getRank();
790       // ptr = AllocaOp sizeof(MemRefDescriptor)
791       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
792           loc, adaptor.source(), rewriter);
793       // voidptr = BitCastOp srcType* to void*
794       auto voidPtr =
795           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
796               .getResult();
797       // rank = ConstantOp srcRank
798       auto rankVal = rewriter.create<LLVM::ConstantOp>(
799           loc, getIndexType(), rewriter.getIndexAttr(rank));
800       // undef = UndefOp
801       UnrankedMemRefDescriptor memRefDesc =
802           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
803       // d1 = InsertValueOp undef, rank, 0
804       memRefDesc.setRank(rewriter, loc, rankVal);
805       // d2 = InsertValueOp d1, voidptr, 1
806       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
807       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
808 
809     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
810       // Casting from unranked type to ranked.
811       // The operation is assumed to be doing a correct cast. If the destination
812       // type mismatches the unranked the type, it is undefined behavior.
813       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
814       // ptr = ExtractValueOp src, 1
815       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
816       // castPtr = BitCastOp i8* to structTy*
817       auto castPtr =
818           rewriter
819               .create<LLVM::BitcastOp>(
820                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
821               .getResult();
822       // struct = LoadOp castPtr
823       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
824       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
825     } else {
826       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
827     }
828   }
829 };
830 
831 /// Pattern to lower a `memref.copy` to llvm.
832 ///
833 /// For memrefs with identity layouts, the copy is lowered to the llvm
834 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
835 /// to the generic `MemrefCopyFn`.
836 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
837   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
838 
839   LogicalResult
840   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
841                           ConversionPatternRewriter &rewriter) const {
842     auto loc = op.getLoc();
843     auto srcType = op.source().getType().dyn_cast<MemRefType>();
844 
845     MemRefDescriptor srcDesc(adaptor.source());
846 
847     // Compute number of elements.
848     Value numElements = rewriter.create<LLVM::ConstantOp>(
849         loc, getIndexType(), rewriter.getIndexAttr(1));
850     for (int pos = 0; pos < srcType.getRank(); ++pos) {
851       auto size = srcDesc.size(rewriter, loc, pos);
852       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
853     }
854 
855     // Get element size.
856     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
857     // Compute total.
858     Value totalSize =
859         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
860 
861     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
862     Value srcOffset = srcDesc.offset(rewriter, loc);
863     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
864                                                 srcBasePtr, srcOffset);
865     MemRefDescriptor targetDesc(adaptor.target());
866     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
867     Value targetOffset = targetDesc.offset(rewriter, loc);
868     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
869                                                    targetBasePtr, targetOffset);
870     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
871         loc, typeConverter->convertType(rewriter.getI1Type()),
872         rewriter.getBoolAttr(false));
873     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
874                                     isVolatile);
875     rewriter.eraseOp(op);
876 
877     return success();
878   }
879 
880   LogicalResult
881   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
882                              ConversionPatternRewriter &rewriter) const {
883     auto loc = op.getLoc();
884     auto srcType = op.source().getType().cast<BaseMemRefType>();
885     auto targetType = op.target().getType().cast<BaseMemRefType>();
886 
887     // First make sure we have an unranked memref descriptor representation.
888     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
889       auto rank = rewriter.create<LLVM::ConstantOp>(
890           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
891       auto *typeConverter = getTypeConverter();
892       auto ptr =
893           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
894       auto voidPtr =
895           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
896               .getResult();
897       auto unrankedType =
898           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
899       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
900                                             unrankedType,
901                                             ValueRange{rank, voidPtr});
902     };
903 
904     Value unrankedSource = srcType.hasRank()
905                                ? makeUnranked(adaptor.source(), srcType)
906                                : adaptor.source();
907     Value unrankedTarget = targetType.hasRank()
908                                ? makeUnranked(adaptor.target(), targetType)
909                                : adaptor.target();
910 
911     // Now promote the unranked descriptors to the stack.
912     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
913                                                  rewriter.getIndexAttr(1));
914     auto promote = [&](Value desc) {
915       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
916       auto allocated =
917           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
918       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
919       return allocated;
920     };
921 
922     auto sourcePtr = promote(unrankedSource);
923     auto targetPtr = promote(unrankedTarget);
924 
925     unsigned typeSize =
926         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
927     auto elemSize = rewriter.create<LLVM::ConstantOp>(
928         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
929     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
930         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
931     rewriter.create<LLVM::CallOp>(loc, copyFn,
932                                   ValueRange{elemSize, sourcePtr, targetPtr});
933     rewriter.eraseOp(op);
934 
935     return success();
936   }
937 
938   LogicalResult
939   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
940                   ConversionPatternRewriter &rewriter) const override {
941     auto srcType = op.source().getType().cast<BaseMemRefType>();
942     auto targetType = op.target().getType().cast<BaseMemRefType>();
943 
944     auto isContiguousMemrefType = [](BaseMemRefType type) {
945       auto memrefType = type.dyn_cast<mlir::MemRefType>();
946       // We can use memcpy for memrefs if they have an identity layout or are
947       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
948       // special case handled by memrefCopy.
949       return memrefType &&
950              (memrefType.getLayout().isIdentity() ||
951               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
952                isStaticShapeAndContiguousRowMajor(memrefType)));
953     };
954 
955     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
956       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
957 
958     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
959   }
960 };
961 
962 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
963 /// memref type. In unranked case, the fields are extracted from the underlying
964 /// ranked descriptor.
965 static void extractPointersAndOffset(Location loc,
966                                      ConversionPatternRewriter &rewriter,
967                                      LLVMTypeConverter &typeConverter,
968                                      Value originalOperand,
969                                      Value convertedOperand,
970                                      Value *allocatedPtr, Value *alignedPtr,
971                                      Value *offset = nullptr) {
972   Type operandType = originalOperand.getType();
973   if (operandType.isa<MemRefType>()) {
974     MemRefDescriptor desc(convertedOperand);
975     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
976     *alignedPtr = desc.alignedPtr(rewriter, loc);
977     if (offset != nullptr)
978       *offset = desc.offset(rewriter, loc);
979     return;
980   }
981 
982   unsigned memorySpace =
983       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
984   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
985   Type llvmElementType = typeConverter.convertType(elementType);
986   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
987       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
988 
989   // Extract pointer to the underlying ranked memref descriptor and cast it to
990   // ElemType**.
991   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
992   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
993 
994   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
995       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
996   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
997       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
998   if (offset != nullptr) {
999     *offset = UnrankedMemRefDescriptor::offset(
1000         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1001   }
1002 }
1003 
1004 struct MemRefReinterpretCastOpLowering
1005     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1006   using ConvertOpToLLVMPattern<
1007       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1008 
1009   LogicalResult
1010   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1011                   ConversionPatternRewriter &rewriter) const override {
1012     Type srcType = castOp.source().getType();
1013 
1014     Value descriptor;
1015     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1016                                                adaptor, &descriptor)))
1017       return failure();
1018     rewriter.replaceOp(castOp, {descriptor});
1019     return success();
1020   }
1021 
1022 private:
1023   LogicalResult convertSourceMemRefToDescriptor(
1024       ConversionPatternRewriter &rewriter, Type srcType,
1025       memref::ReinterpretCastOp castOp,
1026       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1027     MemRefType targetMemRefType =
1028         castOp.getResult().getType().cast<MemRefType>();
1029     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1030                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1031     if (!llvmTargetDescriptorTy)
1032       return failure();
1033 
1034     // Create descriptor.
1035     Location loc = castOp.getLoc();
1036     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1037 
1038     // Set allocated and aligned pointers.
1039     Value allocatedPtr, alignedPtr;
1040     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1041                              castOp.source(), adaptor.source(), &allocatedPtr,
1042                              &alignedPtr);
1043     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1044     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1045 
1046     // Set offset.
1047     if (castOp.isDynamicOffset(0))
1048       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1049     else
1050       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1051 
1052     // Set sizes and strides.
1053     unsigned dynSizeId = 0;
1054     unsigned dynStrideId = 0;
1055     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1056       if (castOp.isDynamicSize(i))
1057         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1058       else
1059         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1060 
1061       if (castOp.isDynamicStride(i))
1062         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1063       else
1064         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1065     }
1066     *descriptor = desc;
1067     return success();
1068   }
1069 };
1070 
1071 struct MemRefReshapeOpLowering
1072     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1073   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1074 
1075   LogicalResult
1076   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1077                   ConversionPatternRewriter &rewriter) const override {
1078     Type srcType = reshapeOp.source().getType();
1079 
1080     Value descriptor;
1081     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1082                                                adaptor, &descriptor)))
1083       return failure();
1084     rewriter.replaceOp(reshapeOp, {descriptor});
1085     return success();
1086   }
1087 
1088 private:
1089   LogicalResult
1090   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1091                                   Type srcType, memref::ReshapeOp reshapeOp,
1092                                   memref::ReshapeOp::Adaptor adaptor,
1093                                   Value *descriptor) const {
1094     // Conversion for statically-known shape args is performed via
1095     // `memref_reinterpret_cast`.
1096     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1097     if (shapeMemRefType.hasStaticShape())
1098       return failure();
1099 
1100     // The shape is a rank-1 tensor with unknown length.
1101     Location loc = reshapeOp.getLoc();
1102     MemRefDescriptor shapeDesc(adaptor.shape());
1103     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1104 
1105     // Extract address space and element type.
1106     auto targetType =
1107         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1108     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1109     Type elementType = targetType.getElementType();
1110 
1111     // Create the unranked memref descriptor that holds the ranked one. The
1112     // inner descriptor is allocated on stack.
1113     auto targetDesc = UnrankedMemRefDescriptor::undef(
1114         rewriter, loc, typeConverter->convertType(targetType));
1115     targetDesc.setRank(rewriter, loc, resultRank);
1116     SmallVector<Value, 4> sizes;
1117     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1118                                            targetDesc, sizes);
1119     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1120         loc, getVoidPtrType(), sizes.front(), llvm::None);
1121     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1122 
1123     // Extract pointers and offset from the source memref.
1124     Value allocatedPtr, alignedPtr, offset;
1125     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1126                              reshapeOp.source(), adaptor.source(),
1127                              &allocatedPtr, &alignedPtr, &offset);
1128 
1129     // Set pointers and offset.
1130     Type llvmElementType = typeConverter->convertType(elementType);
1131     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1132         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1133     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1134                                               elementPtrPtrType, allocatedPtr);
1135     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1136                                             underlyingDescPtr,
1137                                             elementPtrPtrType, alignedPtr);
1138     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1139                                         underlyingDescPtr, elementPtrPtrType,
1140                                         offset);
1141 
1142     // Use the offset pointer as base for further addressing. Copy over the new
1143     // shape and compute strides. For this, we create a loop from rank-1 to 0.
1144     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1145         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1146         elementPtrPtrType);
1147     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1148         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1149     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1150     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1151     Value resultRankMinusOne =
1152         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1153 
1154     Block *initBlock = rewriter.getInsertionBlock();
1155     Type indexType = getTypeConverter()->getIndexType();
1156     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1157 
1158     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1159                                             {indexType, indexType}, {loc, loc});
1160 
1161     // Move the remaining initBlock ops to condBlock.
1162     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1163     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1164 
1165     rewriter.setInsertionPointToEnd(initBlock);
1166     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1167                                 condBlock);
1168     rewriter.setInsertionPointToStart(condBlock);
1169     Value indexArg = condBlock->getArgument(0);
1170     Value strideArg = condBlock->getArgument(1);
1171 
1172     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1173     Value pred = rewriter.create<LLVM::ICmpOp>(
1174         loc, IntegerType::get(rewriter.getContext(), 1),
1175         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1176 
1177     Block *bodyBlock =
1178         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1179     rewriter.setInsertionPointToStart(bodyBlock);
1180 
1181     // Copy size from shape to descriptor.
1182     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1183     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1184         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1185     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1186     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1187                                       targetSizesBase, indexArg, size);
1188 
1189     // Write stride value and compute next one.
1190     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1191                                         targetStridesBase, indexArg, strideArg);
1192     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1193 
1194     // Decrement loop counter and branch back.
1195     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1196     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1197                                 condBlock);
1198 
1199     Block *remainder =
1200         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1201 
1202     // Hook up the cond exit to the remainder.
1203     rewriter.setInsertionPointToEnd(condBlock);
1204     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1205                                     llvm::None);
1206 
1207     // Reset position to beginning of new remainder block.
1208     rewriter.setInsertionPointToStart(remainder);
1209 
1210     *descriptor = targetDesc;
1211     return success();
1212   }
1213 };
1214 
1215 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1216 /// `Value`s.
1217 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1218                                       Type &llvmIndexType,
1219                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1220   return llvm::to_vector<4>(
1221       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1222         if (auto attr = value.dyn_cast<Attribute>())
1223           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1224         return value.get<Value>();
1225       }));
1226 }
1227 
1228 /// Compute a map that for a given dimension of the expanded type gives the
1229 /// dimension in the collapsed type it maps to. Essentially its the inverse of
1230 /// the `reassocation` maps.
1231 static DenseMap<int64_t, int64_t>
1232 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1233   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1234   for (auto &en : enumerate(reassociation)) {
1235     for (auto dim : en.value())
1236       expandedDimToCollapsedDim[dim] = en.index();
1237   }
1238   return expandedDimToCollapsedDim;
1239 }
1240 
1241 static OpFoldResult
1242 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1243                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1244                          MemRefDescriptor &inDesc,
1245                          ArrayRef<int64_t> inStaticShape,
1246                          ArrayRef<ReassociationIndices> reassocation,
1247                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1248   int64_t outDimSize = outStaticShape[outDimIndex];
1249   if (!ShapedType::isDynamic(outDimSize))
1250     return b.getIndexAttr(outDimSize);
1251 
1252   // Calculate the multiplication of all the out dim sizes except the
1253   // current dim.
1254   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1255   int64_t otherDimSizesMul = 1;
1256   for (auto otherDimIndex : reassocation[inDimIndex]) {
1257     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1258       continue;
1259     int64_t otherDimSize = outStaticShape[otherDimIndex];
1260     assert(!ShapedType::isDynamic(otherDimSize) &&
1261            "single dimension cannot be expanded into multiple dynamic "
1262            "dimensions");
1263     otherDimSizesMul *= otherDimSize;
1264   }
1265 
1266   // outDimSize = inDimSize / otherOutDimSizesMul
1267   int64_t inDimSize = inStaticShape[inDimIndex];
1268   Value inDimSizeDynamic =
1269       ShapedType::isDynamic(inDimSize)
1270           ? inDesc.size(b, loc, inDimIndex)
1271           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1272                                        b.getIndexAttr(inDimSize));
1273   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1274       loc, inDimSizeDynamic,
1275       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1276                                  b.getIndexAttr(otherDimSizesMul)));
1277   return outDimSizeDynamic;
1278 }
1279 
1280 static OpFoldResult getCollapsedOutputDimSize(
1281     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1282     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1283     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1284   if (!ShapedType::isDynamic(outDimSize))
1285     return b.getIndexAttr(outDimSize);
1286 
1287   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1288   Value outDimSizeDynamic = c1;
1289   for (auto inDimIndex : reassocation[outDimIndex]) {
1290     int64_t inDimSize = inStaticShape[inDimIndex];
1291     Value inDimSizeDynamic =
1292         ShapedType::isDynamic(inDimSize)
1293             ? inDesc.size(b, loc, inDimIndex)
1294             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1295                                          b.getIndexAttr(inDimSize));
1296     outDimSizeDynamic =
1297         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1298   }
1299   return outDimSizeDynamic;
1300 }
1301 
1302 static SmallVector<OpFoldResult, 4>
1303 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1304                         ArrayRef<ReassociationIndices> reassociation,
1305                         ArrayRef<int64_t> inStaticShape,
1306                         MemRefDescriptor &inDesc,
1307                         ArrayRef<int64_t> outStaticShape) {
1308   return llvm::to_vector<4>(llvm::map_range(
1309       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1310         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1311                                          outStaticShape[outDimIndex],
1312                                          inStaticShape, inDesc, reassociation);
1313       }));
1314 }
1315 
1316 static SmallVector<OpFoldResult, 4>
1317 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1318                        ArrayRef<ReassociationIndices> reassociation,
1319                        ArrayRef<int64_t> inStaticShape,
1320                        MemRefDescriptor &inDesc,
1321                        ArrayRef<int64_t> outStaticShape) {
1322   DenseMap<int64_t, int64_t> outDimToInDimMap =
1323       getExpandedDimToCollapsedDimMap(reassociation);
1324   return llvm::to_vector<4>(llvm::map_range(
1325       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1326         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1327                                         outStaticShape, inDesc, inStaticShape,
1328                                         reassociation, outDimToInDimMap);
1329       }));
1330 }
1331 
1332 static SmallVector<Value>
1333 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1334                       ArrayRef<ReassociationIndices> reassociation,
1335                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1336                       ArrayRef<int64_t> outStaticShape) {
1337   return outStaticShape.size() < inStaticShape.size()
1338              ? getAsValues(b, loc, llvmIndexType,
1339                            getCollapsedOutputShape(b, loc, llvmIndexType,
1340                                                    reassociation, inStaticShape,
1341                                                    inDesc, outStaticShape))
1342              : getAsValues(b, loc, llvmIndexType,
1343                            getExpandedOutputShape(b, loc, llvmIndexType,
1344                                                   reassociation, inStaticShape,
1345                                                   inDesc, outStaticShape));
1346 }
1347 
1348 static void fillInStridesForExpandedMemDescriptor(
1349     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1350     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1351   // See comments for computeExpandedLayoutMap for details on how the strides
1352   // are calculated.
1353   for (auto &en : llvm::enumerate(reassociation)) {
1354     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1355     for (auto dstIndex : llvm::reverse(en.value())) {
1356       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1357       Value size = dstDesc.size(b, loc, dstIndex);
1358       currentStrideToExpand =
1359           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1360     }
1361   }
1362 }
1363 
1364 static void fillInStridesForCollapsedMemDescriptor(
1365     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1366     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1367     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1368   // See comments for computeCollapsedLayoutMap for details on how the strides
1369   // are calculated.
1370   auto srcShape = srcType.getShape();
1371   for (auto &en : llvm::enumerate(reassociation)) {
1372     rewriter.setInsertionPoint(op);
1373     auto dstIndex = en.index();
1374     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1375     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1376       ref = ref.drop_back();
1377     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1378       dstDesc.setStride(rewriter, loc, dstIndex,
1379                         srcDesc.stride(rewriter, loc, ref.back()));
1380     } else {
1381       // Iterate over the source strides in reverse order. Skip over the
1382       // dimensions whose size is 1.
1383       // TODO: we should take the minimum stride in the reassociation group
1384       // instead of just the first where the dimension is not 1.
1385       //
1386       // +------------------------------------------------------+
1387       // | curEntry:                                            |
1388       // |   %srcStride = strides[srcIndex]                     |
1389       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1390       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1391       // +-------------------------+----------------------------+  |
1392       //                           |                               |
1393       //                           v                               |
1394       //            +-----------------------------+                |
1395       //            | nextEntry:                  |                |
1396       //            |   ...                       +---+            |
1397       //            +--------------+--------------+   |            |
1398       //                           |                  |            |
1399       //                           v                  |            |
1400       //            +-----------------------------+   |            |
1401       //            | nextEntry:                  |   |            |
1402       //            |   ...                       |   |            |
1403       //            +--------------+--------------+   |   +--------+
1404       //                           |                  |   |
1405       //                           v                  v   v
1406       //   +--------------------------------------------------+
1407       //   | continue(%newStride):                            |
1408       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1409       //   +--------------------------------------------------+
1410       OpBuilder::InsertionGuard guard(rewriter);
1411       Block *initBlock = rewriter.getInsertionBlock();
1412       Block *continueBlock =
1413           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1414       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1415       rewriter.setInsertionPointToStart(continueBlock);
1416       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1417 
1418       Block *curEntryBlock = initBlock;
1419       Block *nextEntryBlock;
1420       for (auto srcIndex : llvm::reverse(ref)) {
1421         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1422           continue;
1423         rewriter.setInsertionPointToEnd(curEntryBlock);
1424         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1425         if (srcIndex == ref.front()) {
1426           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1427           break;
1428         }
1429         Value one = rewriter.create<LLVM::ConstantOp>(
1430             loc, typeConverter->convertType(rewriter.getI64Type()),
1431             rewriter.getI32IntegerAttr(1));
1432         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1433             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1434             one);
1435         {
1436           OpBuilder::InsertionGuard guard(rewriter);
1437           nextEntryBlock = rewriter.createBlock(
1438               initBlock->getParent(), Region::iterator(continueBlock), {});
1439         }
1440         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1441                                         srcStride, nextEntryBlock, llvm::None);
1442         curEntryBlock = nextEntryBlock;
1443       }
1444     }
1445   }
1446 }
1447 
1448 static void fillInDynamicStridesForMemDescriptor(
1449     ConversionPatternRewriter &b, Location loc, Operation *op,
1450     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1451     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1452     ArrayRef<ReassociationIndices> reassociation) {
1453   if (srcType.getRank() > dstType.getRank())
1454     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1455                                            srcDesc, dstDesc, reassociation);
1456   else
1457     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1458                                           reassociation);
1459 }
1460 
1461 // ReshapeOp creates a new view descriptor of the proper rank.
1462 // For now, the only conversion supported is for target MemRef with static sizes
1463 // and strides.
1464 template <typename ReshapeOp>
1465 class ReassociatingReshapeOpConversion
1466     : public ConvertOpToLLVMPattern<ReshapeOp> {
1467 public:
1468   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1469   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1470 
1471   LogicalResult
1472   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1473                   ConversionPatternRewriter &rewriter) const override {
1474     MemRefType dstType = reshapeOp.getResultType();
1475     MemRefType srcType = reshapeOp.getSrcType();
1476 
1477     int64_t offset;
1478     SmallVector<int64_t, 4> strides;
1479     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1480       return rewriter.notifyMatchFailure(
1481           reshapeOp, "failed to get stride and offset exprs");
1482     }
1483 
1484     MemRefDescriptor srcDesc(adaptor.src());
1485     Location loc = reshapeOp->getLoc();
1486     auto dstDesc = MemRefDescriptor::undef(
1487         rewriter, loc, this->typeConverter->convertType(dstType));
1488     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1489     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1490     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1491 
1492     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1493     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1494     Type llvmIndexType =
1495         this->typeConverter->convertType(rewriter.getIndexType());
1496     SmallVector<Value> dstShape = getDynamicOutputShape(
1497         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1498         srcStaticShape, srcDesc, dstStaticShape);
1499     for (auto &en : llvm::enumerate(dstShape))
1500       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1501 
1502     auto isStaticStride = [](int64_t stride) {
1503       return !ShapedType::isDynamicStrideOrOffset(stride);
1504     };
1505     if (llvm::all_of(strides, isStaticStride)) {
1506       for (auto &en : llvm::enumerate(strides))
1507         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1508     } else if (srcType.getLayout().isIdentity() &&
1509                dstType.getLayout().isIdentity()) {
1510       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1511                                                    rewriter.getIndexAttr(1));
1512       Value stride = c1;
1513       for (auto dimIndex :
1514            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1515         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1516         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1517       }
1518     } else {
1519       // There could be mixed static/dynamic strides. For simplicity, we
1520       // recompute all strides if there is at least one dynamic stride.
1521       fillInDynamicStridesForMemDescriptor(
1522           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1523           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1524     }
1525     rewriter.replaceOp(reshapeOp, {dstDesc});
1526     return success();
1527   }
1528 };
1529 
1530 /// Conversion pattern that transforms a subview op into:
1531 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1532 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1533 ///      and stride.
1534 /// The subview op is replaced by the descriptor.
1535 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1536   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1537 
1538   LogicalResult
1539   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1540                   ConversionPatternRewriter &rewriter) const override {
1541     auto loc = subViewOp.getLoc();
1542 
1543     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1544     auto sourceElementTy =
1545         typeConverter->convertType(sourceMemRefType.getElementType());
1546 
1547     auto viewMemRefType = subViewOp.getType();
1548     auto inferredType = memref::SubViewOp::inferResultType(
1549                             subViewOp.getSourceType(),
1550                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1551                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1552                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1553                             .cast<MemRefType>();
1554     auto targetElementTy =
1555         typeConverter->convertType(viewMemRefType.getElementType());
1556     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1557     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1558         !LLVM::isCompatibleType(sourceElementTy) ||
1559         !LLVM::isCompatibleType(targetElementTy) ||
1560         !LLVM::isCompatibleType(targetDescTy))
1561       return failure();
1562 
1563     // Extract the offset and strides from the type.
1564     int64_t offset;
1565     SmallVector<int64_t, 4> strides;
1566     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1567     if (failed(successStrides))
1568       return failure();
1569 
1570     // Create the descriptor.
1571     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1572       return failure();
1573     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1574     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1575 
1576     // Copy the buffer pointer from the old descriptor to the new one.
1577     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1578     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1579         loc,
1580         LLVM::LLVMPointerType::get(targetElementTy,
1581                                    viewMemRefType.getMemorySpaceAsInt()),
1582         extracted);
1583     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1584 
1585     // Copy the aligned pointer from the old descriptor to the new one.
1586     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1587     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1588         loc,
1589         LLVM::LLVMPointerType::get(targetElementTy,
1590                                    viewMemRefType.getMemorySpaceAsInt()),
1591         extracted);
1592     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1593 
1594     size_t inferredShapeRank = inferredType.getRank();
1595     size_t resultShapeRank = viewMemRefType.getRank();
1596 
1597     // Extract strides needed to compute offset.
1598     SmallVector<Value, 4> strideValues;
1599     strideValues.reserve(inferredShapeRank);
1600     for (unsigned i = 0; i < inferredShapeRank; ++i)
1601       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1602 
1603     // Offset.
1604     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1605     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1606       targetMemRef.setConstantOffset(rewriter, loc, offset);
1607     } else {
1608       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1609       // `inferredShapeRank` may be larger than the number of offset operands
1610       // because of trailing semantics. In this case, the offset is guaranteed
1611       // to be interpreted as 0 and we can just skip the extra dimensions.
1612       for (unsigned i = 0, e = std::min(inferredShapeRank,
1613                                         subViewOp.getMixedOffsets().size());
1614            i < e; ++i) {
1615         Value offset =
1616             // TODO: need OpFoldResult ODS adaptor to clean this up.
1617             subViewOp.isDynamicOffset(i)
1618                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1619                 : rewriter.create<LLVM::ConstantOp>(
1620                       loc, llvmIndexType,
1621                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1622         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1623         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1624       }
1625       targetMemRef.setOffset(rewriter, loc, baseOffset);
1626     }
1627 
1628     // Update sizes and strides.
1629     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1630     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1631     assert(mixedSizes.size() == mixedStrides.size() &&
1632            "expected sizes and strides of equal length");
1633     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1634     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1635          i >= 0 && j >= 0; --i) {
1636       if (unusedDims.test(i))
1637         continue;
1638 
1639       // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
1640       // In this case, the size is guaranteed to be interpreted as Dim and the
1641       // stride as 1.
1642       Value size, stride;
1643       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1644         // If the static size is available, use it directly. This is similar to
1645         // the folding of dim(constant-op) but removes the need for dim to be
1646         // aware of LLVM constants and for this pass to be aware of std
1647         // constants.
1648         int64_t staticSize =
1649             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1650         if (staticSize != ShapedType::kDynamicSize) {
1651           size = rewriter.create<LLVM::ConstantOp>(
1652               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1653         } else {
1654           Value pos = rewriter.create<LLVM::ConstantOp>(
1655               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1656           Value dim =
1657               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1658           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1659               loc, llvmIndexType, dim);
1660           size = cast.getResult(0);
1661         }
1662         stride = rewriter.create<LLVM::ConstantOp>(
1663             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1664       } else {
1665         // TODO: need OpFoldResult ODS adaptor to clean this up.
1666         size =
1667             subViewOp.isDynamicSize(i)
1668                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1669                 : rewriter.create<LLVM::ConstantOp>(
1670                       loc, llvmIndexType,
1671                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1672         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1673           stride = rewriter.create<LLVM::ConstantOp>(
1674               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1675         } else {
1676           stride =
1677               subViewOp.isDynamicStride(i)
1678                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1679                   : rewriter.create<LLVM::ConstantOp>(
1680                         loc, llvmIndexType,
1681                         rewriter.getI64IntegerAttr(
1682                             subViewOp.getStaticStride(i)));
1683           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1684         }
1685       }
1686       targetMemRef.setSize(rewriter, loc, j, size);
1687       targetMemRef.setStride(rewriter, loc, j, stride);
1688       j--;
1689     }
1690 
1691     rewriter.replaceOp(subViewOp, {targetMemRef});
1692     return success();
1693   }
1694 };
1695 
1696 /// Conversion pattern that transforms a transpose op into:
1697 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
1698 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
1699 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1700 ///      and stride. Size and stride are permutations of the original values.
1701 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1702 /// The transpose op is replaced by the alloca'ed pointer.
1703 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1704 public:
1705   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1706 
1707   LogicalResult
1708   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1709                   ConversionPatternRewriter &rewriter) const override {
1710     auto loc = transposeOp.getLoc();
1711     MemRefDescriptor viewMemRef(adaptor.in());
1712 
1713     // No permutation, early exit.
1714     if (transposeOp.permutation().isIdentity())
1715       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1716 
1717     auto targetMemRef = MemRefDescriptor::undef(
1718         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1719 
1720     // Copy the base and aligned pointers from the old descriptor to the new
1721     // one.
1722     targetMemRef.setAllocatedPtr(rewriter, loc,
1723                                  viewMemRef.allocatedPtr(rewriter, loc));
1724     targetMemRef.setAlignedPtr(rewriter, loc,
1725                                viewMemRef.alignedPtr(rewriter, loc));
1726 
1727     // Copy the offset pointer from the old descriptor to the new one.
1728     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1729 
1730     // Iterate over the dimensions and apply size/stride permutation.
1731     for (const auto &en :
1732          llvm::enumerate(transposeOp.permutation().getResults())) {
1733       int sourcePos = en.index();
1734       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1735       targetMemRef.setSize(rewriter, loc, targetPos,
1736                            viewMemRef.size(rewriter, loc, sourcePos));
1737       targetMemRef.setStride(rewriter, loc, targetPos,
1738                              viewMemRef.stride(rewriter, loc, sourcePos));
1739     }
1740 
1741     rewriter.replaceOp(transposeOp, {targetMemRef});
1742     return success();
1743   }
1744 };
1745 
1746 /// Conversion pattern that transforms an op into:
1747 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1748 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1749 ///      and stride.
1750 /// The view op is replaced by the descriptor.
1751 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1752   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1753 
1754   // Build and return the value for the idx^th shape dimension, either by
1755   // returning the constant shape dimension or counting the proper dynamic size.
1756   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1757                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1758                 unsigned idx) const {
1759     assert(idx < shape.size());
1760     if (!ShapedType::isDynamic(shape[idx]))
1761       return createIndexConstant(rewriter, loc, shape[idx]);
1762     // Count the number of dynamic dims in range [0, idx]
1763     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1764       return ShapedType::isDynamic(v);
1765     });
1766     return dynamicSizes[nDynamic];
1767   }
1768 
1769   // Build and return the idx^th stride, either by returning the constant stride
1770   // or by computing the dynamic stride from the current `runningStride` and
1771   // `nextSize`. The caller should keep a running stride and update it with the
1772   // result returned by this function.
1773   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1774                   ArrayRef<int64_t> strides, Value nextSize,
1775                   Value runningStride, unsigned idx) const {
1776     assert(idx < strides.size());
1777     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1778       return createIndexConstant(rewriter, loc, strides[idx]);
1779     if (nextSize)
1780       return runningStride
1781                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1782                  : nextSize;
1783     assert(!runningStride);
1784     return createIndexConstant(rewriter, loc, 1);
1785   }
1786 
1787   LogicalResult
1788   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1789                   ConversionPatternRewriter &rewriter) const override {
1790     auto loc = viewOp.getLoc();
1791 
1792     auto viewMemRefType = viewOp.getType();
1793     auto targetElementTy =
1794         typeConverter->convertType(viewMemRefType.getElementType());
1795     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1796     if (!targetDescTy || !targetElementTy ||
1797         !LLVM::isCompatibleType(targetElementTy) ||
1798         !LLVM::isCompatibleType(targetDescTy))
1799       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1800              failure();
1801 
1802     int64_t offset;
1803     SmallVector<int64_t, 4> strides;
1804     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1805     if (failed(successStrides))
1806       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1807     assert(offset == 0 && "expected offset to be 0");
1808 
1809     // Create the descriptor.
1810     MemRefDescriptor sourceMemRef(adaptor.source());
1811     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1812 
1813     // Field 1: Copy the allocated pointer, used for malloc/free.
1814     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1815     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1816     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1817         loc,
1818         LLVM::LLVMPointerType::get(targetElementTy,
1819                                    srcMemRefType.getMemorySpaceAsInt()),
1820         allocatedPtr);
1821     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1822 
1823     // Field 2: Copy the actual aligned pointer to payload.
1824     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1825     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1826                                               alignedPtr, adaptor.byte_shift());
1827     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1828         loc,
1829         LLVM::LLVMPointerType::get(targetElementTy,
1830                                    srcMemRefType.getMemorySpaceAsInt()),
1831         alignedPtr);
1832     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1833 
1834     // Field 3: The offset in the resulting type must be 0. This is because of
1835     // the type change: an offset on srcType* may not be expressible as an
1836     // offset on dstType*.
1837     targetMemRef.setOffset(rewriter, loc,
1838                            createIndexConstant(rewriter, loc, offset));
1839 
1840     // Early exit for 0-D corner case.
1841     if (viewMemRefType.getRank() == 0)
1842       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1843 
1844     // Fields 4 and 5: Update sizes and strides.
1845     if (strides.back() != 1)
1846       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1847              failure();
1848     Value stride = nullptr, nextSize = nullptr;
1849     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1850       // Update size.
1851       Value size =
1852           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1853       targetMemRef.setSize(rewriter, loc, i, size);
1854       // Update stride.
1855       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1856       targetMemRef.setStride(rewriter, loc, i, stride);
1857       nextSize = size;
1858     }
1859 
1860     rewriter.replaceOp(viewOp, {targetMemRef});
1861     return success();
1862   }
1863 };
1864 
1865 //===----------------------------------------------------------------------===//
1866 // AtomicRMWOpLowering
1867 //===----------------------------------------------------------------------===//
1868 
1869 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1870 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1871 static Optional<LLVM::AtomicBinOp>
1872 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1873   switch (atomicOp.kind()) {
1874   case arith::AtomicRMWKind::addf:
1875     return LLVM::AtomicBinOp::fadd;
1876   case arith::AtomicRMWKind::addi:
1877     return LLVM::AtomicBinOp::add;
1878   case arith::AtomicRMWKind::assign:
1879     return LLVM::AtomicBinOp::xchg;
1880   case arith::AtomicRMWKind::maxs:
1881     return LLVM::AtomicBinOp::max;
1882   case arith::AtomicRMWKind::maxu:
1883     return LLVM::AtomicBinOp::umax;
1884   case arith::AtomicRMWKind::mins:
1885     return LLVM::AtomicBinOp::min;
1886   case arith::AtomicRMWKind::minu:
1887     return LLVM::AtomicBinOp::umin;
1888   case arith::AtomicRMWKind::ori:
1889     return LLVM::AtomicBinOp::_or;
1890   case arith::AtomicRMWKind::andi:
1891     return LLVM::AtomicBinOp::_and;
1892   default:
1893     return llvm::None;
1894   }
1895   llvm_unreachable("Invalid AtomicRMWKind");
1896 }
1897 
1898 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1899   using Base::Base;
1900 
1901   LogicalResult
1902   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1903                   ConversionPatternRewriter &rewriter) const override {
1904     if (failed(match(atomicOp)))
1905       return failure();
1906     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1907     if (!maybeKind)
1908       return failure();
1909     auto resultType = adaptor.value().getType();
1910     auto memRefType = atomicOp.getMemRefType();
1911     auto dataPtr =
1912         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1913                              adaptor.indices(), rewriter);
1914     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1915         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1916         LLVM::AtomicOrdering::acq_rel);
1917     return success();
1918   }
1919 };
1920 
1921 } // namespace
1922 
1923 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1924                                                   RewritePatternSet &patterns) {
1925   // clang-format off
1926   patterns.add<
1927       AllocaOpLowering,
1928       AllocaScopeOpLowering,
1929       AtomicRMWOpLowering,
1930       AssumeAlignmentOpLowering,
1931       DimOpLowering,
1932       GenericAtomicRMWOpLowering,
1933       GlobalMemrefOpLowering,
1934       GetGlobalMemrefOpLowering,
1935       LoadOpLowering,
1936       MemRefCastOpLowering,
1937       MemRefCopyOpLowering,
1938       MemRefReinterpretCastOpLowering,
1939       MemRefReshapeOpLowering,
1940       PrefetchOpLowering,
1941       RankOpLowering,
1942       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1943       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1944       StoreOpLowering,
1945       SubViewOpLowering,
1946       TransposeOpLowering,
1947       ViewOpLowering>(converter);
1948   // clang-format on
1949   auto allocLowering = converter.getOptions().allocLowering;
1950   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1951     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1952   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1953     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1954 }
1955 
1956 namespace {
1957 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1958   MemRefToLLVMPass() = default;
1959 
1960   void runOnOperation() override {
1961     Operation *op = getOperation();
1962     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1963     LowerToLLVMOptions options(&getContext(),
1964                                dataLayoutAnalysis.getAtOrAbove(op));
1965     options.allocLowering =
1966         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1967                          : LowerToLLVMOptions::AllocLowering::Malloc);
1968     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1969       options.overrideIndexBitwidth(indexBitwidth);
1970 
1971     LLVMTypeConverter typeConverter(&getContext(), options,
1972                                     &dataLayoutAnalysis);
1973     RewritePatternSet patterns(&getContext());
1974     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1975     LLVMConversionTarget target(getContext());
1976     target.addLegalOp<func::FuncOp>();
1977     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1978       signalPassFailure();
1979   }
1980 };
1981 } // namespace
1982 
1983 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1984   return std::make_unique<MemRefToLLVMPass>();
1985 }
1986