1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
30   return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
31 }
32 
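/// Lowers `memref.alloc` to a `malloc` call plus construction of a memref
/// descriptor. As a rough, illustrative sketch (not verbatim output; the exact
/// constants, index bitwidth, and descriptor fields depend on the type
/// converter and data layout), `memref.alloc() : memref<4xf32>` becomes:
///
///   %size = llvm.mlir.constant(<4 * sizeof(f32)> : index) : i64
///   %raw  = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
///   %ptr  = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
///
/// with `%ptr` recorded as both the allocated and the aligned pointer of the
/// descriptor (a realignment sequence is emitted only when an alignment is
/// requested or implied by the element type).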
33 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
34   AllocOpLowering(LLVMTypeConverter &converter)
35       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
36                                 converter) {}
37 
38   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
39                                           Location loc, Value sizeBytes,
40                                           Operation *op) const override {
41     // Heap allocations.
42     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
43     MemRefType memRefType = allocOp.getType();
44 
45     Value alignment;
46     if (auto alignmentAttr = allocOp.getAlignment()) {
47       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
48     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
49       // In the case where no alignment is specified, we may want to override
50       // `malloc's` behavior. `malloc` typically aligns at the size of the
51       // biggest scalar on a target HW. For non-scalars, use the natural
52       // alignment of the LLVM type given by the LLVM DataLayout.
53       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
54     }
55 
56     if (alignment) {
57       // Adjust the allocation size to consider alignment.
58       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
59     }
60 
61     // Allocate the underlying buffer and store a pointer to it in the MemRef
62     // descriptor.
63     Type elementPtrType = this->getElementPtrType(memRefType);
64     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
65         allocOp->getParentOfType<ModuleOp>(), getIndexType());
66     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
67                                   getVoidPtrType());
68     Value allocatedPtr =
69         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
70 
71     Value alignedPtr = allocatedPtr;
72     if (alignment) {
73       // Compute the aligned type pointer.
74       Value allocatedInt =
75           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
76       Value alignmentInt =
77           createAligned(rewriter, loc, allocatedInt, alignment);
78       alignedPtr =
79           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
80     }
81 
82     return std::make_tuple(allocatedPtr, alignedPtr);
83   }
84 };
85 
86 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
87   AlignedAllocOpLowering(LLVMTypeConverter &converter)
88       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
89                                 converter) {}
90 
91   /// Returns the memref's element size in bytes using the data layout active at
92   /// `op`.
93   // TODO: there are other places where this is used. Expose publicly?
94   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
95     const DataLayout *layout = &defaultLayout;
96     if (const DataLayoutAnalysis *analysis =
97             getTypeConverter()->getDataLayoutAnalysis()) {
98       layout = &analysis->getAbove(op);
99     }
100     Type elementType = memRefType.getElementType();
101     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
102       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
103                                                          *layout);
104     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
105       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
106           memRefElementType, *layout);
107     return layout->getTypeSize(elementType);
108   }
109 
110   /// Returns true if the memref size in bytes is known to be a multiple of
111   /// factor assuming the data layout active at `op`.
112   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
113                               Operation *op) const {
114     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
115     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
116       if (ShapedType::isDynamic(type.getDimSize(i)))
117         continue;
118       sizeDivisor = sizeDivisor * type.getDimSize(i);
119     }
120     return sizeDivisor % factor == 0;
121   }
122 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of the alignment.
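  /// For example (numbers are illustrative and assume the default data
  /// layout): an 8-byte element gives max(16, PowerOf2Ceil(8)) = 16, while a
  /// hypothetical 20-byte element would give max(16, PowerOf2Ceil(20)) = 32.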
126   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
127     if (Optional<uint64_t> alignment = allocOp.getAlignment())
128       return *alignment;
129 
    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if necessary.
133     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
134     return std::max(kMinAlignedAllocAlignment,
135                     llvm::PowerOf2Ceil(eltSizeBytes));
136   }
137 
138   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
139                                           Location loc, Value sizeBytes,
140                                           Operation *op) const override {
141     // Heap allocations.
142     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
143     MemRefType memRefType = allocOp.getType();
144     int64_t alignment = getAllocationAlignment(allocOp);
145     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
146 
147     // aligned_alloc requires size to be a multiple of alignment; we will pad
148     // the size to the next multiple if necessary.
149     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
150       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
151 
152     Type elementPtrType = this->getElementPtrType(memRefType);
153     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
154         allocOp->getParentOfType<ModuleOp>(), getIndexType());
155     auto results =
156         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
157                        getVoidPtrType());
158     Value allocatedPtr =
159         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
160 
161     return std::make_tuple(allocatedPtr, allocatedPtr);
162   }
163 
164   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
165   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
166 
167   /// Default layout to use in absence of the corresponding analysis.
168   DataLayout defaultLayout;
169 };
170 
// Out-of-line definition, required until C++17.
172 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
173 
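/// Lowers `memref.alloca` to an `llvm.alloca` of the element type.
/// Schematically (the size operand and index bitwidth come from the shared
/// alloc-like lowering, so this is only a sketch):
///
///   %ptr = llvm.alloca %size x f32 {alignment = 8 : i64}
///            : (i64) -> !llvm.ptr<f32>
///
/// The same pointer is used as both the allocated and the aligned pointer.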
174 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
175   AllocaOpLowering(LLVMTypeConverter &converter)
176       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
177                                 converter) {}
178 
  /// Allocates the underlying buffer on the stack with an LLVM `alloca`.
  /// Since `alloca` already yields a pointer to the element type honoring the
  /// requested alignment, the allocated and aligned pointers are identical.
182   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
183                                           Location loc, Value sizeBytes,
184                                           Operation *op) const override {
185 
    // With alloca, one gets a pointer to the element type right away; no
    // separate aligned pointer needs to be computed.
188     auto allocaOp = cast<memref::AllocaOp>(op);
189     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
190 
191     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
192         loc, elementPtrType, sizeBytes,
193         allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);
194 
195     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
196   }
197 };
198 
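/// Lowers `memref.alloca_scope` by inlining its body into the enclosing
/// function and bracketing it with stack save/restore intrinsics, so that any
/// `alloca` issued inside the scope is reclaimed on exit. Roughly:
///
///   ^before:
///     %sp = llvm.intr.stacksave : !llvm.ptr<i8>
///     llvm.br ^body
///   ^body:
///     ... inlined scope body, may contain llvm.alloca ...
///     llvm.intr.stackrestore %sp
///     llvm.br ^after(<results>)
///   ^after(...):
///     ... code following the alloca_scope ...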
199 struct AllocaScopeOpLowering
200     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
201   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
202 
203   LogicalResult
204   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
205                   ConversionPatternRewriter &rewriter) const override {
206     OpBuilder::InsertionGuard guard(rewriter);
207     Location loc = allocaScopeOp.getLoc();
208 
209     // Split the current block before the AllocaScopeOp to create the inlining
210     // point.
211     auto *currentBlock = rewriter.getInsertionBlock();
212     auto *remainingOpsBlock =
213         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
214     Block *continueBlock;
215     if (allocaScopeOp.getNumResults() == 0) {
216       continueBlock = remainingOpsBlock;
217     } else {
218       continueBlock = rewriter.createBlock(
219           remainingOpsBlock, allocaScopeOp.getResultTypes(),
220           SmallVector<Location>(allocaScopeOp->getNumResults(),
221                                 allocaScopeOp.getLoc()));
222       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
223     }
224 
225     // Inline body region.
226     Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
227     Block *afterBody = &allocaScopeOp.getBodyRegion().back();
228     rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
229 
230     // Save stack and then branch into the body of the region.
231     rewriter.setInsertionPointToEnd(currentBlock);
232     auto stackSaveOp =
233         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
234     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
235 
236     // Replace the alloca_scope return with a branch that jumps out of the body.
237     // Stack restore before leaving the body region.
238     rewriter.setInsertionPointToEnd(afterBody);
239     auto returnOp =
240         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
241     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
242         returnOp, returnOp.getResults(), continueBlock);
243 
    // Insert stack restore before jumping out of the body of the region.
245     rewriter.setInsertionPoint(branchOp);
246     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
247 
248     // Replace the op with values return from the body region.
249     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
250 
251     return success();
252   }
253 };
254 
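/// Lowers `memref.assume_alignment` to an assumption on the aligned pointer.
/// For an alignment of 16, the emitted IR is roughly (illustrative only):
///
///   %p    = llvm.ptrtoint %alignedPtr : !llvm.ptr<f32> to i64
///   %mask = llvm.mlir.constant(15 : index) : i64
///   %and  = llvm.and %p, %mask : i64
///   %zero = llvm.mlir.constant(0 : index) : i64
///   %cond = llvm.icmp "eq" %and, %zero : i64
///   llvm.intr.assume %cond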
255 struct AssumeAlignmentOpLowering
256     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
257   using ConvertOpToLLVMPattern<
258       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
259 
260   LogicalResult
261   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
262                   ConversionPatternRewriter &rewriter) const override {
263     Value memref = adaptor.getMemref();
264     unsigned alignment = op.getAlignment();
265     auto loc = op.getLoc();
266 
267     MemRefDescriptor memRefDescriptor(memref);
268     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
269 
270     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
271     // the asserted memref.alignedPtr isn't used anywhere else, as the real
272     // users like load/store/views always re-extract memref.alignedPtr as they
273     // get lowered.
274     //
275     // This relies on LLVM's CSE optimization (potentially after SROA), since
276     // after CSE all memref.alignedPtr instances get de-duplicated into the same
277     // pointer SSA value.
278     auto intPtrType =
279         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
280     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
281     Value mask =
282         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
283     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
284     rewriter.create<LLVM::AssumeOp>(
285         loc, rewriter.create<LLVM::ICmpOp>(
286                  loc, LLVM::ICmpPredicate::eq,
287                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
288 
289     rewriter.eraseOp(op);
290     return success();
291   }
292 };
293 
294 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
295 // The memref descriptor being an SSA value, there is no need to clean it up
296 // in any way.
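//
// Schematically (assuming an f32 element type):
//   %void = llvm.bitcast %allocatedPtr : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()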
297 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
298   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
299 
300   explicit DeallocOpLowering(LLVMTypeConverter &converter)
301       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
302 
303   LogicalResult
304   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
305                   ConversionPatternRewriter &rewriter) const override {
306     // Insert the `free` declaration if it is not already present.
307     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
308     MemRefDescriptor memref(adaptor.getMemref());
309     Value casted = rewriter.create<LLVM::BitcastOp>(
310         op.getLoc(), getVoidPtrType(),
311         memref.allocatedPtr(rewriter, op.getLoc()));
312     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
313         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
314     return success();
315   }
316 };
317 
318 // A `dim` is converted to a constant for static sizes and to an access to the
319 // size stored in the memref descriptor for dynamic sizes.
320 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
321   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
322 
323   LogicalResult
324   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
325                   ConversionPatternRewriter &rewriter) const override {
326     Type operandType = dimOp.getSource().getType();
327     if (operandType.isa<UnrankedMemRefType>()) {
328       rewriter.replaceOp(
329           dimOp, {extractSizeOfUnrankedMemRef(
330                      operandType, dimOp, adaptor.getOperands(), rewriter)});
331 
332       return success();
333     }
334     if (operandType.isa<MemRefType>()) {
335       rewriter.replaceOp(
336           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
337                                             adaptor.getOperands(), rewriter)});
338       return success();
339     }
340     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
341   }
342 
343 private:
344   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
345                                     OpAdaptor adaptor,
346                                     ConversionPatternRewriter &rewriter) const {
347     Location loc = dimOp.getLoc();
348 
349     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
350     auto scalarMemRefType =
351         MemRefType::get({}, unrankedMemRefType.getElementType());
352     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
353 
354     // Extract pointer to the underlying ranked descriptor and bitcast it to a
355     // memref<element_type> descriptor pointer to minimize the number of GEP
356     // operations.
357     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
358     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
359     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
360         loc,
361         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
362                                    addressSpace),
363         underlyingRankedDesc);
364 
365     // Get pointer to offset field of memref<element_type> descriptor.
366     Type indexPtrTy = LLVM::LLVMPointerType::get(
367         getTypeConverter()->getIndexType(), addressSpace);
368     Value two = rewriter.create<LLVM::ConstantOp>(
369         loc, typeConverter->convertType(rewriter.getI32Type()),
370         rewriter.getI32IntegerAttr(2));
371     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
372         loc, indexPtrTy, scalarMemRefDescPtr,
373         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
374 
    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
377     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
378         loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
379     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
380                                                  ValueRange({idxPlusOne}));
381     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
382   }
383 
384   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
385     if (Optional<int64_t> idx = dimOp.getConstantIndex())
386       return idx;
387 
388     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
389       return constantOp.getValue()
390           .cast<IntegerAttr>()
391           .getValue()
392           .getSExtValue();
393 
394     return llvm::None;
395   }
396 
397   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
398                                   OpAdaptor adaptor,
399                                   ConversionPatternRewriter &rewriter) const {
400     Location loc = dimOp.getLoc();
401 
    // Take advantage of a constant index when available.
403     MemRefType memRefType = operandType.cast<MemRefType>();
404     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
405       int64_t i = *index;
406       if (memRefType.isDynamicDim(i)) {
407         // extract dynamic size from the memref descriptor.
408         MemRefDescriptor descriptor(adaptor.getSource());
409         return descriptor.size(rewriter, loc, i);
410       }
411       // Use constant for static size.
412       int64_t dimSize = memRefType.getDimSize(i);
413       return createIndexConstant(rewriter, loc, dimSize);
414     }
415     Value index = adaptor.getIndex();
416     int64_t rank = memRefType.getRank();
417     MemRefDescriptor memrefDescriptor(adaptor.getSource());
418     return memrefDescriptor.size(rewriter, loc, index, rank);
419   }
420 };
421 
422 /// Common base for load and store operations on MemRefs. Restricts the match
423 /// to supported MemRef types. Provides functionality to emit code accessing a
424 /// specific element of the underlying data buffer.
425 template <typename Derived>
426 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
427   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
428   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
429   using Base = LoadStoreOpLowering<Derived>;
430 
431   LogicalResult match(Derived op) const override {
432     MemRefType type = op.getMemRefType();
433     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
434   }
435 };
436 
437 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
438 /// retried until it succeeds in atomically storing a new value into memory.
439 ///
440 ///      +---------------------------------+
441 ///      |   <code before the AtomicRMWOp> |
442 ///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
444 ///      +---------------------------------+
445 ///             |
446 ///  -------|   |
447 ///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
456 ///  |          |        |
457 ///  |-----------        |
458 ///                      v
459 ///      +--------------------------------+
460 ///      | end:                           |
461 ///      |   <code after the AtomicRMWOp> |
462 ///      +--------------------------------+
463 ///
464 struct GenericAtomicRMWOpLowering
465     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
466   using Base::Base;
467 
468   LogicalResult
469   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
470                   ConversionPatternRewriter &rewriter) const override {
471     auto loc = atomicOp.getLoc();
472     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
473 
474     // Split the block into initial, loop, and ending parts.
475     auto *initBlock = rewriter.getInsertionBlock();
476     auto *loopBlock = rewriter.createBlock(
477         initBlock->getParent(), std::next(Region::iterator(initBlock)),
478         valueType, loc);
479     auto *endBlock = rewriter.createBlock(
480         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
481 
482     // Operations range to be moved to `endBlock`.
483     auto opsToMoveStart = atomicOp->getIterator();
484     auto opsToMoveEnd = initBlock->back().getIterator();
485 
486     // Compute the loaded value and branch to the loop block.
487     rewriter.setInsertionPointToEnd(initBlock);
488     auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
489     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
490                                         adaptor.getIndices(), rewriter);
491     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
492     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
493 
494     // Prepare the body of the loop block.
495     rewriter.setInsertionPointToStart(loopBlock);
496 
497     // Clone the GenericAtomicRMWOp region and extract the result.
498     auto loopArgument = loopBlock->getArgument(0);
499     BlockAndValueMapping mapping;
500     mapping.map(atomicOp.getCurrentValue(), loopArgument);
501     Block &entryBlock = atomicOp.body().front();
502     for (auto &nestedOp : entryBlock.without_terminator()) {
503       Operation *clone = rewriter.clone(nestedOp, mapping);
504       mapping.map(nestedOp.getResults(), clone->getResults());
505     }
506     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
507 
508     // Prepare the epilog of the loop block.
509     // Append the cmpxchg op to the end of the loop block.
510     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
511     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
512     auto boolType = IntegerType::get(rewriter.getContext(), 1);
513     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
514                                                      {valueType, boolType});
515     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
516         loc, pairType, dataPtr, loopArgument, result, successOrdering,
517         failureOrdering);
518     // Extract the %new_loaded and %ok values from the pair.
519     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
520         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
521     Value ok = rewriter.create<LLVM::ExtractValueOp>(
522         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
523 
524     // Conditionally branch to the end or back to the loop depending on %ok.
525     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
526                                     loopBlock, newLoaded);
527 
528     rewriter.setInsertionPointToEnd(endBlock);
529     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
530                  std::next(opsToMoveEnd), rewriter);
531 
532     // The 'result' of the atomic_rmw op is the newly loaded value.
533     rewriter.replaceOp(atomicOp, {newLoaded});
534 
535     return success();
536   }
537 
538 private:
539   // Clones a segment of ops [start, end) and erases the original.
540   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
541                     Block::iterator start, Block::iterator end,
542                     ConversionPatternRewriter &rewriter) const {
543     BlockAndValueMapping mapping;
544     mapping.map(oldResult, newResult);
545     SmallVector<Operation *, 2> opsToErase;
546     for (auto it = start; it != end; ++it) {
547       rewriter.clone(*it, mapping);
548       opsToErase.push_back(&*it);
549     }
550     for (auto *it : opsToErase)
551       rewriter.eraseOp(it);
552   }
553 };
554 
555 /// Returns the LLVM type of the global variable given the memref type `type`.
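/// For example, `memref<4x8xf32>` becomes `!llvm.array<4 x array<8 x f32>>`
/// (modulo conversion of the element type by the type converter).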
556 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
557                                           LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
563   Type elementType = typeConverter.convertType(type.getElementType());
564   Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
566   for (int64_t dim : llvm::reverse(type.getShape()))
567     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
568   return arrayTy;
569 }
570 
571 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
572 struct GlobalMemrefOpLowering
573     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
574   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
575 
576   LogicalResult
577   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
578                   ConversionPatternRewriter &rewriter) const override {
579     MemRefType type = global.getType();
580     if (!isConvertibleAndHasIdentityMaps(type))
581       return failure();
582 
583     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
584 
585     LLVM::Linkage linkage =
586         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
587 
588     Attribute initialValue = nullptr;
589     if (!global.isExternal() && !global.isUninitialized()) {
590       auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
591       initialValue = elementsAttr;
592 
593       // For scalar memrefs, the global variable created is of the element type,
594       // so unpack the elements attribute to extract the value.
595       if (type.getRank() == 0)
596         initialValue = elementsAttr.getSplatValue<Attribute>();
597     }
598 
599     uint64_t alignment = global.getAlignment().value_or(0);
600 
601     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
602         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
603         initialValue, alignment, type.getMemorySpaceAsInt());
604     if (!global.isExternal() && global.isUninitialized()) {
605       Block *blk = new Block();
606       newGlobal.getInitializerRegion().push_back(blk);
607       rewriter.setInsertionPointToStart(blk);
608       Value undef[] = {
609           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
610       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
611     }
612     return success();
613   }
614 };
615 
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. It derives from
/// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
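/// Schematically, for `memref.get_global @gv : memref<4x8xf32>` the buffer
/// "allocation" is only an address computation (illustrative sketch):
///
///   %addr = llvm.mlir.addressof @gv : !llvm.ptr<array<4 x array<8 x f32>>>
///   %gep  = llvm.getelementptr %addr[%c0, %c0, %c0]
///             : (!llvm.ptr<array<4 x array<8 x f32>>>, i64, i64, i64)
///             -> !llvm.ptr<f32>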
619 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
620   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
621       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
622                                 converter) {}
623 
624   /// Buffer "allocation" for memref.get_global op is getting the address of
625   /// the global variable referenced.
626   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
627                                           Location loc, Value sizeBytes,
628                                           Operation *op) const override {
629     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
630     MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
631     unsigned memSpace = type.getMemorySpaceAsInt();
632 
633     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
634     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
635         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
636         getGlobalOp.getName());
637 
638     // Get the address of the first element in the array by creating a GEP with
639     // the address of the GV as the base, and (rank + 1) number of 0 indices.
640     Type elementType = typeConverter->convertType(type.getElementType());
641     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
642 
643     SmallVector<Value> operands;
644     operands.insert(operands.end(), type.getRank() + 1,
645                     createIndexConstant(rewriter, loc, 0));
646     auto gep =
647         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
648 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
652     auto intPtrType = getIntPtrType(memSpace);
653     Value deadBeefConst =
654         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
655     auto deadBeefPtr =
656         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
657 
    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
660     return std::make_tuple(deadBeefPtr, gep);
661   }
662 };
663 
664 // Load operation is lowered to obtaining a pointer to the indexed element
665 // and loading it.
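//
// For example, `memref.load %m[%i, %j] : memref<?x?xf32>` becomes,
// schematically:
//   %ptr = llvm.getelementptr ...  // strided address computed from the
//                                  // descriptor offset/strides and %i, %j
//   %val = llvm.load %ptr : !llvm.ptr<f32>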
666 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
667   using Base::Base;
668 
669   LogicalResult
670   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
671                   ConversionPatternRewriter &rewriter) const override {
672     auto type = loadOp.getMemRefType();
673 
674     Value dataPtr =
675         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
676                              adaptor.getIndices(), rewriter);
677     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
678     return success();
679   }
680 };
681 
682 // Store operation is lowered to obtaining a pointer to the indexed element,
683 // and storing the given value to it.
684 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
685   using Base::Base;
686 
687   LogicalResult
688   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
689                   ConversionPatternRewriter &rewriter) const override {
690     auto type = op.getMemRefType();
691 
692     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
693                                          adaptor.getIndices(), rewriter);
694     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
695     return success();
696   }
697 };
698 
699 // The prefetch operation is lowered in a way similar to the load operation
700 // except that the llvm.prefetch operation is used for replacement.
701 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
702   using Base::Base;
703 
704   LogicalResult
705   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
706                   ConversionPatternRewriter &rewriter) const override {
707     auto type = prefetchOp.getMemRefType();
708     auto loc = prefetchOp.getLoc();
709 
710     Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
711                                          adaptor.getIndices(), rewriter);
712 
713     // Replace with llvm.prefetch.
714     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
715     auto isWrite = rewriter.create<LLVM::ConstantOp>(
716         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
717     auto localityHint = rewriter.create<LLVM::ConstantOp>(
718         loc, llvmI32Type,
719         rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
720     auto isData = rewriter.create<LLVM::ConstantOp>(
721         loc, llvmI32Type,
722         rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
723 
724     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
725                                                 localityHint, isData);
726     return success();
727   }
728 };
729 
730 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
731   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
732 
733   LogicalResult
734   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
735                   ConversionPatternRewriter &rewriter) const override {
736     Location loc = op.getLoc();
737     Type operandType = op.getMemref().getType();
738     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
739       UnrankedMemRefDescriptor desc(adaptor.getMemref());
740       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
741       return success();
742     }
743     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
744       rewriter.replaceOp(
745           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
746       return success();
747     }
748     return failure();
749   }
750 };
751 
752 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
753   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
754 
755   LogicalResult match(memref::CastOp memRefCastOp) const override {
756     Type srcType = memRefCastOp.getOperand().getType();
757     Type dstType = memRefCastOp.getType();
758 
    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
764     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
765       return success(typeConverter->convertType(srcType) ==
766                      typeConverter->convertType(dstType));
767 
    // At least one of the operands is of unranked memref type.
769     assert(srcType.isa<UnrankedMemRefType>() ||
770            dstType.isa<UnrankedMemRefType>());
771 
    // Unranked-to-unranked casts are disallowed.
773     return !(srcType.isa<UnrankedMemRefType>() &&
774              dstType.isa<UnrankedMemRefType>())
775                ? success()
776                : failure();
777   }
778 
779   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
780                ConversionPatternRewriter &rewriter) const override {
781     auto srcType = memRefCastOp.getOperand().getType();
782     auto dstType = memRefCastOp.getType();
783     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
784     auto loc = memRefCastOp.getLoc();
785 
786     // For ranked/ranked case, just keep the original descriptor.
787     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
788       return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
789 
790     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
791       // Casting ranked to unranked memref type
792       // Set the rank in the destination from the memref type
793       // Allocate space on the stack and copy the src memref descriptor
794       // Set the ptr in the destination to the stack space
795       auto srcMemRefType = srcType.cast<MemRefType>();
796       int64_t rank = srcMemRefType.getRank();
797       // ptr = AllocaOp sizeof(MemRefDescriptor)
798       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
799           loc, adaptor.getSource(), rewriter);
800       // voidptr = BitCastOp srcType* to void*
801       auto voidPtr =
802           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
803               .getResult();
804       // rank = ConstantOp srcRank
805       auto rankVal = rewriter.create<LLVM::ConstantOp>(
806           loc, getIndexType(), rewriter.getIndexAttr(rank));
807       // undef = UndefOp
808       UnrankedMemRefDescriptor memRefDesc =
809           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
810       // d1 = InsertValueOp undef, rank, 0
811       memRefDesc.setRank(rewriter, loc, rankVal);
812       // d2 = InsertValueOp d1, voidptr, 1
813       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
814       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
815 
816     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
817       // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
820       UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
821       // ptr = ExtractValueOp src, 1
822       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
823       // castPtr = BitCastOp i8* to structTy*
824       auto castPtr =
825           rewriter
826               .create<LLVM::BitcastOp>(
827                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
828               .getResult();
829       // struct = LoadOp castPtr
830       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
831       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
832     } else {
833       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
834     }
835   }
836 };
837 
838 /// Pattern to lower a `memref.copy` to llvm.
839 ///
840 /// For memrefs with identity layouts, the copy is lowered to the llvm
841 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
842 /// to the generic `MemrefCopyFn`.
843 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
844   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
845 
846   LogicalResult
847   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
848                           ConversionPatternRewriter &rewriter) const {
849     auto loc = op.getLoc();
850     auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
851 
852     MemRefDescriptor srcDesc(adaptor.getSource());
853 
854     // Compute number of elements.
855     Value numElements = rewriter.create<LLVM::ConstantOp>(
856         loc, getIndexType(), rewriter.getIndexAttr(1));
857     for (int pos = 0; pos < srcType.getRank(); ++pos) {
858       auto size = srcDesc.size(rewriter, loc, pos);
859       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
860     }
861 
862     // Get element size.
863     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
864     // Compute total.
865     Value totalSize =
866         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
867 
868     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
869     Value srcOffset = srcDesc.offset(rewriter, loc);
870     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
871                                                 srcBasePtr, srcOffset);
872     MemRefDescriptor targetDesc(adaptor.getTarget());
873     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
874     Value targetOffset = targetDesc.offset(rewriter, loc);
875     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
876                                                    targetBasePtr, targetOffset);
877     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
878         loc, typeConverter->convertType(rewriter.getI1Type()),
879         rewriter.getBoolAttr(false));
880     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
881                                     isVolatile);
882     rewriter.eraseOp(op);
883 
884     return success();
885   }
886 
887   LogicalResult
888   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
889                              ConversionPatternRewriter &rewriter) const {
890     auto loc = op.getLoc();
891     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
892     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
893 
894     // First make sure we have an unranked memref descriptor representation.
895     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
896       auto rank = rewriter.create<LLVM::ConstantOp>(
897           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
898       auto *typeConverter = getTypeConverter();
899       auto ptr =
900           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
901       auto voidPtr =
902           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
903               .getResult();
904       auto unrankedType =
905           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
906       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
907                                             unrankedType,
908                                             ValueRange{rank, voidPtr});
909     };
910 
911     Value unrankedSource = srcType.hasRank()
912                                ? makeUnranked(adaptor.getSource(), srcType)
913                                : adaptor.getSource();
914     Value unrankedTarget = targetType.hasRank()
915                                ? makeUnranked(adaptor.getTarget(), targetType)
916                                : adaptor.getTarget();
917 
918     // Now promote the unranked descriptors to the stack.
919     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
920                                                  rewriter.getIndexAttr(1));
921     auto promote = [&](Value desc) {
922       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
923       auto allocated =
924           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
925       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
926       return allocated;
927     };
928 
929     auto sourcePtr = promote(unrankedSource);
930     auto targetPtr = promote(unrankedTarget);
931 
932     unsigned typeSize =
933         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
934     auto elemSize = rewriter.create<LLVM::ConstantOp>(
935         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
936     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
937         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
938     rewriter.create<LLVM::CallOp>(loc, copyFn,
939                                   ValueRange{elemSize, sourcePtr, targetPtr});
940     rewriter.eraseOp(op);
941 
942     return success();
943   }
944 
945   LogicalResult
946   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
947                   ConversionPatternRewriter &rewriter) const override {
948     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
949     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
950 
951     auto isContiguousMemrefType = [](BaseMemRefType type) {
952       auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
956       return memrefType &&
957              (memrefType.getLayout().isIdentity() ||
958               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
959                isStaticShapeAndContiguousRowMajor(memrefType)));
960     };
961 
962     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
963       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
964 
965     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
966   }
967 };
968 
969 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
970 /// memref type. In unranked case, the fields are extracted from the underlying
971 /// ranked descriptor.
972 static void extractPointersAndOffset(Location loc,
973                                      ConversionPatternRewriter &rewriter,
974                                      LLVMTypeConverter &typeConverter,
975                                      Value originalOperand,
976                                      Value convertedOperand,
977                                      Value *allocatedPtr, Value *alignedPtr,
978                                      Value *offset = nullptr) {
979   Type operandType = originalOperand.getType();
980   if (operandType.isa<MemRefType>()) {
981     MemRefDescriptor desc(convertedOperand);
982     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
983     *alignedPtr = desc.alignedPtr(rewriter, loc);
984     if (offset != nullptr)
985       *offset = desc.offset(rewriter, loc);
986     return;
987   }
988 
989   unsigned memorySpace =
990       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
991   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
992   Type llvmElementType = typeConverter.convertType(elementType);
993   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
994       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
995 
996   // Extract pointer to the underlying ranked memref descriptor and cast it to
997   // ElemType**.
998   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
999   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1000 
1001   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1002       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1003   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1004       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1005   if (offset != nullptr) {
1006     *offset = UnrankedMemRefDescriptor::offset(
1007         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1008   }
1009 }
1010 
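/// Lowers `memref.reinterpret_cast` by building a new memref descriptor: the
/// allocated/aligned pointers are taken from the source (extracting them from
/// the underlying ranked descriptor when the source is unranked), while
/// offset, sizes, and strides come from the op's static attributes or dynamic
/// operands. E.g. (schematic, not verbatim syntax):
///
///   memref.reinterpret_cast %src to offset: [0], sizes: [%d, 8],
///       strides: [8, 1] : memref<*xf32> to memref<?x8xf32>
///
/// produces an `llvm.mlir.undef` descriptor filled field by field with
/// `llvm.insertvalue` operations.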
1011 struct MemRefReinterpretCastOpLowering
1012     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1013   using ConvertOpToLLVMPattern<
1014       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1015 
1016   LogicalResult
1017   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1018                   ConversionPatternRewriter &rewriter) const override {
1019     Type srcType = castOp.getSource().getType();
1020 
1021     Value descriptor;
1022     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1023                                                adaptor, &descriptor)))
1024       return failure();
1025     rewriter.replaceOp(castOp, {descriptor});
1026     return success();
1027   }
1028 
1029 private:
1030   LogicalResult convertSourceMemRefToDescriptor(
1031       ConversionPatternRewriter &rewriter, Type srcType,
1032       memref::ReinterpretCastOp castOp,
1033       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1034     MemRefType targetMemRefType =
1035         castOp.getResult().getType().cast<MemRefType>();
1036     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1037                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1038     if (!llvmTargetDescriptorTy)
1039       return failure();
1040 
1041     // Create descriptor.
1042     Location loc = castOp.getLoc();
1043     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1044 
1045     // Set allocated and aligned pointers.
1046     Value allocatedPtr, alignedPtr;
1047     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1048                              castOp.getSource(), adaptor.getSource(),
1049                              &allocatedPtr, &alignedPtr);
1050     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1051     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1052 
1053     // Set offset.
1054     if (castOp.isDynamicOffset(0))
1055       desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1056     else
1057       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1058 
1059     // Set sizes and strides.
1060     unsigned dynSizeId = 0;
1061     unsigned dynStrideId = 0;
1062     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1063       if (castOp.isDynamicSize(i))
1064         desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1065       else
1066         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1067 
1068       if (castOp.isDynamicStride(i))
1069         desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1070       else
1071         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1072     }
1073     *descriptor = desc;
1074     return success();
1075   }
1076 };
1077 
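/// Lowers `memref.reshape`. When the shape operand has a static shape, the
/// result is ranked and a new ranked descriptor is populated directly (sizes
/// come from constants or loads of the shape operand, strides are the running
/// product of the inner sizes). Otherwise the result is unranked: an
/// underlying ranked descriptor is allocated on the stack and its size and
/// stride fields are filled by an explicit control-flow loop over the runtime
/// rank, as built below.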
1078 struct MemRefReshapeOpLowering
1079     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1080   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1081 
1082   LogicalResult
1083   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1084                   ConversionPatternRewriter &rewriter) const override {
1085     Type srcType = reshapeOp.getSource().getType();
1086 
1087     Value descriptor;
1088     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1089                                                adaptor, &descriptor)))
1090       return failure();
1091     rewriter.replaceOp(reshapeOp, {descriptor});
1092     return success();
1093   }
1094 
1095 private:
1096   LogicalResult
1097   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1098                                   Type srcType, memref::ReshapeOp reshapeOp,
1099                                   memref::ReshapeOp::Adaptor adaptor,
1100                                   Value *descriptor) const {
1101     auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
1102     if (shapeMemRefType.hasStaticShape()) {
1103       MemRefType targetMemRefType =
1104           reshapeOp.getResult().getType().cast<MemRefType>();
1105       auto llvmTargetDescriptorTy =
1106           typeConverter->convertType(targetMemRefType)
1107               .dyn_cast_or_null<LLVM::LLVMStructType>();
1108       if (!llvmTargetDescriptorTy)
1109         return failure();
1110 
1111       // Create descriptor.
1112       Location loc = reshapeOp.getLoc();
1113       auto desc =
1114           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1115 
1116       // Set allocated and aligned pointers.
1117       Value allocatedPtr, alignedPtr;
1118       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1119                                reshapeOp.getSource(), adaptor.getSource(),
1120                                &allocatedPtr, &alignedPtr);
1121       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1122       desc.setAlignedPtr(rewriter, loc, alignedPtr);
1123 
1124       // Extract the offset and strides from the type.
1125       int64_t offset;
1126       SmallVector<int64_t> strides;
1127       if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1128         return rewriter.notifyMatchFailure(
1129             reshapeOp, "failed to get stride and offset exprs");
1130 
1131       if (!isStaticStrideOrOffset(offset))
1132         return rewriter.notifyMatchFailure(reshapeOp,
1133                                            "dynamic offset is unsupported");
1134 
1135       desc.setConstantOffset(rewriter, loc, offset);
1136 
1137       assert(targetMemRefType.getLayout().isIdentity() &&
1138              "Identity layout map is a precondition of a valid reshape op");
1139 
1140       Value stride = nullptr;
1141       int64_t targetRank = targetMemRefType.getRank();
1142       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1143         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, materialize it
          // directly as a constant; dynamic strides are instead computed below
          // as the running product of the inner dimension sizes.
1146           stride = createIndexConstant(rewriter, loc, strides[i]);
1147         } else if (!stride) {
1148           // `stride` is null only in the first iteration of the loop.  However,
1149           // since the target memref has an identity layout, we can safely set
1150           // the innermost stride to 1.
1151           stride = createIndexConstant(rewriter, loc, 1);
1152         }
1153 
1154         Value dimSize;
1155         int64_t size = targetMemRefType.getDimSize(i);
1156         // If the size of this dimension is dynamic, then load it at runtime
1157         // from the shape operand.
1158         if (!ShapedType::isDynamic(size)) {
1159           dimSize = createIndexConstant(rewriter, loc, size);
1160         } else {
1161           Value shapeOp = reshapeOp.getShape();
1162           Value index = createIndexConstant(rewriter, loc, i);
1163           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1164         }
1165 
1166         desc.setSize(rewriter, loc, i, dimSize);
1167         desc.setStride(rewriter, loc, i, stride);
1168 
1169         // Prepare the stride value for the next dimension.
1170         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1171       }
1172 
1173       *descriptor = desc;
1174       return success();
1175     }
1176 
    // The shape is a rank-1 memref with unknown (dynamic) length.
1178     Location loc = reshapeOp.getLoc();
1179     MemRefDescriptor shapeDesc(adaptor.getShape());
1180     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1181 
1182     // Extract address space and element type.
1183     auto targetType =
1184         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1185     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1186     Type elementType = targetType.getElementType();
1187 
1188     // Create the unranked memref descriptor that holds the ranked one. The
1189     // inner descriptor is allocated on stack.
1190     auto targetDesc = UnrankedMemRefDescriptor::undef(
1191         rewriter, loc, typeConverter->convertType(targetType));
1192     targetDesc.setRank(rewriter, loc, resultRank);
1193     SmallVector<Value, 4> sizes;
1194     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1195                                            targetDesc, sizes);
1196     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1197         loc, getVoidPtrType(), sizes.front(), llvm::None);
1198     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1199 
1200     // Extract pointers and offset from the source memref.
1201     Value allocatedPtr, alignedPtr, offset;
1202     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1203                              reshapeOp.getSource(), adaptor.getSource(),
1204                              &allocatedPtr, &alignedPtr, &offset);
1205 
1206     // Set pointers and offset.
1207     Type llvmElementType = typeConverter->convertType(elementType);
1208     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1209         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1210     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1211                                               elementPtrPtrType, allocatedPtr);
1212     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1213                                             underlyingDescPtr,
1214                                             elementPtrPtrType, alignedPtr);
1215     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1216                                         underlyingDescPtr, elementPtrPtrType,
1217                                         offset);
1218 
    // Use the offset pointer as the base for further addressing. Copy over the
    // new shape and compute the strides. For this, we create a loop that runs
    // from rank - 1 down to 0.
1221     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1222         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1223         elementPtrPtrType);
1224     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1225         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1226     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1227     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1228     Value resultRankMinusOne =
1229         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1230 
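    // The generated control flow is, roughly (illustrative sketch):
    //   init:            br cond(rank - 1, 1)
    //   cond(i, stride): (i >= 0) ? br body : br remainder
    //   body:            sizes[i] = shape[i]; strides[i] = stride;
    //                    br cond(i - 1, stride * sizes[i])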
1231     Block *initBlock = rewriter.getInsertionBlock();
1232     Type indexType = getTypeConverter()->getIndexType();
1233     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1234 
1235     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1236                                             {indexType, indexType}, {loc, loc});
1237 
1238     // Move the remaining initBlock ops to condBlock.
1239     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1240     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1241 
1242     rewriter.setInsertionPointToEnd(initBlock);
1243     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1244                                 condBlock);
1245     rewriter.setInsertionPointToStart(condBlock);
1246     Value indexArg = condBlock->getArgument(0);
1247     Value strideArg = condBlock->getArgument(1);
1248 
1249     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1250     Value pred = rewriter.create<LLVM::ICmpOp>(
1251         loc, IntegerType::get(rewriter.getContext(), 1),
1252         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1253 
1254     Block *bodyBlock =
1255         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1256     rewriter.setInsertionPointToStart(bodyBlock);
1257 
1258     // Copy size from shape to descriptor.
1259     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1260     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1261         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1262     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1263     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1264                                       targetSizesBase, indexArg, size);
1265 
1266     // Write stride value and compute next one.
1267     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1268                                         targetStridesBase, indexArg, strideArg);
1269     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1270 
1271     // Decrement loop counter and branch back.
1272     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1273     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1274                                 condBlock);
1275 
1276     Block *remainder =
1277         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1278 
1279     // Hook up the cond exit to the remainder.
1280     rewriter.setInsertionPointToEnd(condBlock);
1281     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1282                                     llvm::None);
1283 
1284     // Reset position to beginning of new remainder block.
1285     rewriter.setInsertionPointToStart(remainder);
1286 
1287     *descriptor = targetDesc;
1288     return success();
1289   }
1290 };
1291 
1292 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1293 /// `Value`s.
1294 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1295                                       Type &llvmIndexType,
1296                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1297   return llvm::to_vector<4>(
1298       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1299         if (auto attr = value.dyn_cast<Attribute>())
1300           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1301         return value.get<Value>();
1302       }));
1303 }
1304 
/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
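/// For example (illustrative), with reassociation `[[0, 1], [2]]`, expanded
/// dimensions 0 and 1 both map to collapsed dimension 0, and expanded
/// dimension 2 maps to collapsed dimension 1.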
1308 static DenseMap<int64_t, int64_t>
1309 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1310   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1311   for (auto &en : enumerate(reassociation)) {
1312     for (auto dim : en.value())
1313       expandedDimToCollapsedDim[dim] = en.index();
1314   }
1315   return expandedDimToCollapsedDim;
1316 }
1317 
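/// Compute the size of one dimension of the expanded (higher-rank) memref. A
/// static size is returned as an attribute; a dynamic size is computed at
/// runtime by dividing the size of the corresponding collapsed dimension by
/// the product of the other (necessarily static) sizes in its reassociation
/// group. For example (illustrative), expanding `memref<?xf32>` into
/// `memref<?x4xf32>` with reassociation `[[0, 1]]` computes
/// size(out dim 0) = size(in dim 0) / 4.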
1318 static OpFoldResult
1319 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1320                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1321                          MemRefDescriptor &inDesc,
1322                          ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
1324                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1325   int64_t outDimSize = outStaticShape[outDimIndex];
1326   if (!ShapedType::isDynamic(outDimSize))
1327     return b.getIndexAttr(outDimSize);
1328 
  // Compute the product of all the output dim sizes in this reassociation
  // group except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
      continue;
1336     int64_t otherDimSize = outStaticShape[otherDimIndex];
1337     assert(!ShapedType::isDynamic(otherDimSize) &&
1338            "single dimension cannot be expanded into multiple dynamic "
1339            "dimensions");
1340     otherDimSizesMul *= otherDimSize;
1341   }
1342 
  // outDimSize = inDimSize / otherDimSizesMul
1344   int64_t inDimSize = inStaticShape[inDimIndex];
1345   Value inDimSizeDynamic =
1346       ShapedType::isDynamic(inDimSize)
1347           ? inDesc.size(b, loc, inDimIndex)
1348           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1349                                        b.getIndexAttr(inDimSize));
1350   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1351       loc, inDimSizeDynamic,
1352       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1353                                  b.getIndexAttr(otherDimSizesMul)));
1354   return outDimSizeDynamic;
1355 }
1356 
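/// Compute the size of one dimension of the collapsed (lower-rank) memref. A
/// static size is returned as an attribute; a dynamic size is computed at
/// runtime as the product of the sizes of all the source dimensions in the
/// reassociation group. For example (illustrative), collapsing
/// `memref<?x4xf32>` into `memref<?xf32>` with reassociation `[[0, 1]]`
/// computes size(out dim 0) = size(in dim 0) * 4.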
1357 static OpFoldResult getCollapsedOutputDimSize(
1358     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1359     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
1361   if (!ShapedType::isDynamic(outDimSize))
1362     return b.getIndexAttr(outDimSize);
1363 
1364   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1365   Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
1367     int64_t inDimSize = inStaticShape[inDimIndex];
1368     Value inDimSizeDynamic =
1369         ShapedType::isDynamic(inDimSize)
1370             ? inDesc.size(b, loc, inDimIndex)
1371             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1372                                          b.getIndexAttr(inDimSize));
1373     outDimSizeDynamic =
1374         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1375   }
1376   return outDimSizeDynamic;
1377 }
1378 
1379 static SmallVector<OpFoldResult, 4>
1380 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1381                         ArrayRef<ReassociationIndices> reassociation,
1382                         ArrayRef<int64_t> inStaticShape,
1383                         MemRefDescriptor &inDesc,
1384                         ArrayRef<int64_t> outStaticShape) {
1385   return llvm::to_vector<4>(llvm::map_range(
1386       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1387         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1388                                          outStaticShape[outDimIndex],
1389                                          inStaticShape, inDesc, reassociation);
1390       }));
1391 }
1392 
1393 static SmallVector<OpFoldResult, 4>
1394 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1395                        ArrayRef<ReassociationIndices> reassociation,
1396                        ArrayRef<int64_t> inStaticShape,
1397                        MemRefDescriptor &inDesc,
1398                        ArrayRef<int64_t> outStaticShape) {
1399   DenseMap<int64_t, int64_t> outDimToInDimMap =
1400       getExpandedDimToCollapsedDimMap(reassociation);
1401   return llvm::to_vector<4>(llvm::map_range(
1402       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1403         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1404                                         outStaticShape, inDesc, inStaticShape,
1405                                         reassociation, outDimToInDimMap);
1406       }));
1407 }
1408 
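/// Compute the runtime shape of the reshape result. If the result rank is
/// lower than the source rank this is a collapse_shape, otherwise an
/// expand_shape.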
1409 static SmallVector<Value>
1410 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1411                       ArrayRef<ReassociationIndices> reassociation,
1412                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1413                       ArrayRef<int64_t> outStaticShape) {
1414   return outStaticShape.size() < inStaticShape.size()
1415              ? getAsValues(b, loc, llvmIndexType,
1416                            getCollapsedOutputShape(b, loc, llvmIndexType,
1417                                                    reassociation, inStaticShape,
1418                                                    inDesc, outStaticShape))
1419              : getAsValues(b, loc, llvmIndexType,
1420                            getExpandedOutputShape(b, loc, llvmIndexType,
1421                                                   reassociation, inStaticShape,
1422                                                   inDesc, outStaticShape));
1423 }
1424 
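/// Fill in the strides of the expanded (higher-rank) descriptor. As a rough
/// sketch of the loop below: within each reassociation group the innermost
/// expanded dimension inherits the stride of the collapsed source dimension,
/// and each outer dimension's stride is the next inner stride multiplied by
/// the next inner size.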
1425 static void fillInStridesForExpandedMemDescriptor(
1426     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1427     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1428   // See comments for computeExpandedLayoutMap for details on how the strides
1429   // are calculated.
1430   for (auto &en : llvm::enumerate(reassociation)) {
1431     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1432     for (auto dstIndex : llvm::reverse(en.value())) {
1433       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1434       Value size = dstDesc.size(b, loc, dstIndex);
1435       currentStrideToExpand =
1436           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1437     }
1438   }
1439 }
1440 
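/// Fill in the strides of the collapsed (lower-rank) descriptor. For each
/// reassociation group, roughly, the stride of the innermost source dimension
/// whose size is not 1 is used; when that choice cannot be made statically, a
/// small CFG (sketched in the diagram below) selects the stride at runtime.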
1441 static void fillInStridesForCollapsedMemDescriptor(
1442     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1443     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1444     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1445   // See comments for computeCollapsedLayoutMap for details on how the strides
1446   // are calculated.
1447   auto srcShape = srcType.getShape();
1448   for (auto &en : llvm::enumerate(reassociation)) {
1449     rewriter.setInsertionPoint(op);
1450     auto dstIndex = en.index();
1451     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1452     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1453       ref = ref.drop_back();
1454     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1455       dstDesc.setStride(rewriter, loc, dstIndex,
1456                         srcDesc.stride(rewriter, loc, ref.back()));
1457     } else {
1458       // Iterate over the source strides in reverse order. Skip over the
1459       // dimensions whose size is 1.
1460       // TODO: we should take the minimum stride in the reassociation group
1461       // instead of just the first where the dimension is not 1.
1462       //
1463       // +------------------------------------------------------+
1464       // | curEntry:                                            |
1465       // |   %srcStride = strides[srcIndex]                     |
1466       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1467       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1468       // +-------------------------+----------------------------+  |
1469       //                           |                               |
1470       //                           v                               |
1471       //            +-----------------------------+                |
1472       //            | nextEntry:                  |                |
1473       //            |   ...                       +---+            |
1474       //            +--------------+--------------+   |            |
1475       //                           |                  |            |
1476       //                           v                  |            |
1477       //            +-----------------------------+   |            |
1478       //            | nextEntry:                  |   |            |
1479       //            |   ...                       |   |            |
1480       //            +--------------+--------------+   |   +--------+
1481       //                           |                  |   |
1482       //                           v                  v   v
1483       //   +--------------------------------------------------+
1484       //   | continue(%newStride):                            |
1485       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1486       //   +--------------------------------------------------+
1487       OpBuilder::InsertionGuard guard(rewriter);
1488       Block *initBlock = rewriter.getInsertionBlock();
1489       Block *continueBlock =
1490           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1491       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1492       rewriter.setInsertionPointToStart(continueBlock);
1493       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1494 
1495       Block *curEntryBlock = initBlock;
1496       Block *nextEntryBlock;
1497       for (auto srcIndex : llvm::reverse(ref)) {
1498         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1499           continue;
1500         rewriter.setInsertionPointToEnd(curEntryBlock);
1501         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1502         if (srcIndex == ref.front()) {
1503           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1504           break;
1505         }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI64IntegerAttr(1));
1509         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1510             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1511             one);
1512         {
1513           OpBuilder::InsertionGuard guard(rewriter);
1514           nextEntryBlock = rewriter.createBlock(
1515               initBlock->getParent(), Region::iterator(continueBlock), {});
1516         }
1517         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1518                                         srcStride, nextEntryBlock, llvm::None);
1519         curEntryBlock = nextEntryBlock;
1520       }
1521     }
1522   }
1523 }
1524 
1525 static void fillInDynamicStridesForMemDescriptor(
1526     ConversionPatternRewriter &b, Location loc, Operation *op,
1527     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1528     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1529     ArrayRef<ReassociationIndices> reassociation) {
1530   if (srcType.getRank() > dstType.getRank())
1531     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1532                                            srcDesc, dstDesc, reassociation);
1533   else
1534     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1535                                           reassociation);
1536 }
1537 
// ReshapeOp creates a new view descriptor of the proper rank.
// Strides are taken verbatim when they are all static, recomputed from the
// sizes when both source and target have identity layouts, and otherwise
// recomputed dynamically from the source descriptor.
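//
// For example (illustrative):
//   %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
//       : memref<4x?x8xf32> into memref<?x8xf32>
//   %1 = memref.expand_shape %arg1 [[0, 1]]
//       : memref<?xf32> into memref<?x4xf32>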
1541 template <typename ReshapeOp>
1542 class ReassociatingReshapeOpConversion
1543     : public ConvertOpToLLVMPattern<ReshapeOp> {
1544 public:
1545   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1546   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1547 
1548   LogicalResult
1549   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1550                   ConversionPatternRewriter &rewriter) const override {
1551     MemRefType dstType = reshapeOp.getResultType();
1552     MemRefType srcType = reshapeOp.getSrcType();
1553 
1554     int64_t offset;
1555     SmallVector<int64_t, 4> strides;
1556     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1557       return rewriter.notifyMatchFailure(
1558           reshapeOp, "failed to get stride and offset exprs");
1559     }
1560 
1561     MemRefDescriptor srcDesc(adaptor.getSrc());
1562     Location loc = reshapeOp->getLoc();
1563     auto dstDesc = MemRefDescriptor::undef(
1564         rewriter, loc, this->typeConverter->convertType(dstType));
1565     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1566     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1567     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1568 
1569     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1570     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1571     Type llvmIndexType =
1572         this->typeConverter->convertType(rewriter.getIndexType());
1573     SmallVector<Value> dstShape = getDynamicOutputShape(
1574         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1575         srcStaticShape, srcDesc, dstStaticShape);
1576     for (auto &en : llvm::enumerate(dstShape))
1577       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1578 
1579     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1580       for (auto &en : llvm::enumerate(strides))
1581         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1582     } else if (srcType.getLayout().isIdentity() &&
1583                dstType.getLayout().isIdentity()) {
1584       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1585                                                    rewriter.getIndexAttr(1));
1586       Value stride = c1;
1587       for (auto dimIndex :
1588            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1589         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1590         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1591       }
1592     } else {
1593       // There could be mixed static/dynamic strides. For simplicity, we
1594       // recompute all strides if there is at least one dynamic stride.
1595       fillInDynamicStridesForMemDescriptor(
1596           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1597           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1598     }
1599     rewriter.replaceOp(reshapeOp, {dstDesc});
1600     return success();
1601   }
1602 };
1603 
1604 /// Conversion pattern that transforms a subview op into:
1605 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1606 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1607 ///      and stride.
1608 /// The subview op is replaced by the descriptor.
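///
/// For example (illustrative), a subview such as
///   %0 = memref.subview %src[%off, 4] [16, %sz] [1, 2]
///       : memref<64x64xf32> to ... (result type elided)
/// is rewritten into a descriptor whose offset, sizes and strides are
/// computed from the mixed static/dynamic operands below.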
1609 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1610   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1611 
1612   LogicalResult
1613   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1614                   ConversionPatternRewriter &rewriter) const override {
1615     auto loc = subViewOp.getLoc();
1616 
1617     auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
1618     auto sourceElementTy =
1619         typeConverter->convertType(sourceMemRefType.getElementType());
1620 
1621     auto viewMemRefType = subViewOp.getType();
1622     auto inferredType =
1623         memref::SubViewOp::inferResultType(
1624             subViewOp.getSourceType(),
1625             extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1626             extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1627             extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
1628             .cast<MemRefType>();
1629     auto targetElementTy =
1630         typeConverter->convertType(viewMemRefType.getElementType());
1631     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1632     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1633         !LLVM::isCompatibleType(sourceElementTy) ||
1634         !LLVM::isCompatibleType(targetElementTy) ||
1635         !LLVM::isCompatibleType(targetDescTy))
1636       return failure();
1637 
1638     // Extract the offset and strides from the type.
1639     int64_t offset;
1640     SmallVector<int64_t, 4> strides;
1641     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1642     if (failed(successStrides))
1643       return failure();
1644 
1645     // Create the descriptor.
1646     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1647       return failure();
1648     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1649     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1650 
1651     // Copy the buffer pointer from the old descriptor to the new one.
1652     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1653     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1654         loc,
1655         LLVM::LLVMPointerType::get(targetElementTy,
1656                                    viewMemRefType.getMemorySpaceAsInt()),
1657         extracted);
1658     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1659 
1660     // Copy the aligned pointer from the old descriptor to the new one.
1661     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1662     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1663         loc,
1664         LLVM::LLVMPointerType::get(targetElementTy,
1665                                    viewMemRefType.getMemorySpaceAsInt()),
1666         extracted);
1667     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1668 
1669     size_t inferredShapeRank = inferredType.getRank();
1670     size_t resultShapeRank = viewMemRefType.getRank();
1671 
1672     // Extract strides needed to compute offset.
1673     SmallVector<Value, 4> strideValues;
1674     strideValues.reserve(inferredShapeRank);
1675     for (unsigned i = 0; i < inferredShapeRank; ++i)
1676       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1677 
1678     // Offset.
1679     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1680     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1681       targetMemRef.setConstantOffset(rewriter, loc, offset);
1682     } else {
1683       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1684       // `inferredShapeRank` may be larger than the number of offset operands
1685       // because of trailing semantics. In this case, the offset is guaranteed
1686       // to be interpreted as 0 and we can just skip the extra dimensions.
1687       for (unsigned i = 0, e = std::min(inferredShapeRank,
1688                                         subViewOp.getMixedOffsets().size());
1689            i < e; ++i) {
1690         Value offset =
1691             // TODO: need OpFoldResult ODS adaptor to clean this up.
1692             subViewOp.isDynamicOffset(i)
1693                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1694                 : rewriter.create<LLVM::ConstantOp>(
1695                       loc, llvmIndexType,
1696                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1697         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1698         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1699       }
1700       targetMemRef.setOffset(rewriter, loc, baseOffset);
1701     }
1702 
1703     // Update sizes and strides.
1704     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1705     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1706     assert(mixedSizes.size() == mixedStrides.size() &&
1707            "expected sizes and strides of equal length");
1708     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1709     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1710          i >= 0 && j >= 0; --i) {
1711       if (unusedDims.test(i))
1712         continue;
1713 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is the
      // corresponding dimension of the source memref and the stride is 1.
1717       Value size, stride;
1718       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1719         // If the static size is available, use it directly. This is similar to
1720         // the folding of dim(constant-op) but removes the need for dim to be
1721         // aware of LLVM constants and for this pass to be aware of std
1722         // constants.
1723         int64_t staticSize =
1724             subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
1725         if (staticSize != ShapedType::kDynamicSize) {
1726           size = rewriter.create<LLVM::ConstantOp>(
1727               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1728         } else {
1729           Value pos = rewriter.create<LLVM::ConstantOp>(
1730               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1731           Value dim =
1732               rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1733           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1734               loc, llvmIndexType, dim);
1735           size = cast.getResult(0);
1736         }
1737         stride = rewriter.create<LLVM::ConstantOp>(
1738             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1739       } else {
1740         // TODO: need OpFoldResult ODS adaptor to clean this up.
1741         size =
1742             subViewOp.isDynamicSize(i)
1743                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1744                 : rewriter.create<LLVM::ConstantOp>(
1745                       loc, llvmIndexType,
1746                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1747         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1748           stride = rewriter.create<LLVM::ConstantOp>(
1749               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1750         } else {
1751           stride =
1752               subViewOp.isDynamicStride(i)
1753                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1754                   : rewriter.create<LLVM::ConstantOp>(
1755                         loc, llvmIndexType,
1756                         rewriter.getI64IntegerAttr(
1757                             subViewOp.getStaticStride(i)));
1758           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1759         }
1760       }
1761       targetMemRef.setSize(rewriter, loc, j, size);
1762       targetMemRef.setStride(rewriter, loc, j, stride);
1763       j--;
1764     }
1765 
1766     rewriter.replaceOp(subViewOp, {targetMemRef});
1767     return success();
1768   }
1769 };
1770 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      sizes and strides permuted according to the transposition map.
/// The transpose op is replaced by the descriptor.
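///
/// For example (illustrative), a permutation map (d0, d1) -> (d1, d0) swaps
/// sizes[0]/sizes[1] and strides[0]/strides[1] in the result descriptor.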
1778 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1779 public:
1780   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1781 
1782   LogicalResult
1783   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1784                   ConversionPatternRewriter &rewriter) const override {
1785     auto loc = transposeOp.getLoc();
1786     MemRefDescriptor viewMemRef(adaptor.getIn());
1787 
1788     // No permutation, early exit.
1789     if (transposeOp.getPermutation().isIdentity())
1790       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1791 
1792     auto targetMemRef = MemRefDescriptor::undef(
1793         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1794 
1795     // Copy the base and aligned pointers from the old descriptor to the new
1796     // one.
1797     targetMemRef.setAllocatedPtr(rewriter, loc,
1798                                  viewMemRef.allocatedPtr(rewriter, loc));
1799     targetMemRef.setAlignedPtr(rewriter, loc,
1800                                viewMemRef.alignedPtr(rewriter, loc));
1801 
1802     // Copy the offset pointer from the old descriptor to the new one.
1803     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1804 
1805     // Iterate over the dimensions and apply size/stride permutation.
1806     for (const auto &en :
1807          llvm::enumerate(transposeOp.getPermutation().getResults())) {
1808       int sourcePos = en.index();
1809       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1810       targetMemRef.setSize(rewriter, loc, targetPos,
1811                            viewMemRef.size(rewriter, loc, sourcePos));
1812       targetMemRef.setStride(rewriter, loc, targetPos,
1813                              viewMemRef.stride(rewriter, loc, sourcePos));
1814     }
1815 
1816     rewriter.replaceOp(transposeOp, {targetMemRef});
1817     return success();
1818   }
1819 };
1820 
1821 /// Conversion pattern that transforms an op into:
1822 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1823 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1824 ///      and stride.
1825 /// The view op is replaced by the descriptor.
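///
/// For example (illustrative):
///   %0 = memref.view %buffer[%byte_shift][%size]
///       : memref<2048xi8> to memref<?x64xf32>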
1826 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1827   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1828 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand (found by counting the dynamic dimensions before `idx`).
1831   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1832                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1833                 unsigned idx) const {
1834     assert(idx < shape.size());
1835     if (!ShapedType::isDynamic(shape[idx]))
1836       return createIndexConstant(rewriter, loc, shape[idx]);
1837     // Count the number of dynamic dims in range [0, idx]
1838     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1839       return ShapedType::isDynamic(v);
1840     });
1841     return dynamicSizes[nDynamic];
1842   }
1843 
1844   // Build and return the idx^th stride, either by returning the constant stride
1845   // or by computing the dynamic stride from the current `runningStride` and
1846   // `nextSize`. The caller should keep a running stride and update it with the
1847   // result returned by this function.
1848   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1849                   ArrayRef<int64_t> strides, Value nextSize,
1850                   Value runningStride, unsigned idx) const {
1851     assert(idx < strides.size());
1852     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1853       return createIndexConstant(rewriter, loc, strides[idx]);
1854     if (nextSize)
1855       return runningStride
1856                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1857                  : nextSize;
1858     assert(!runningStride);
1859     return createIndexConstant(rewriter, loc, 1);
1860   }
1861 
1862   LogicalResult
1863   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1864                   ConversionPatternRewriter &rewriter) const override {
1865     auto loc = viewOp.getLoc();
1866 
1867     auto viewMemRefType = viewOp.getType();
1868     auto targetElementTy =
1869         typeConverter->convertType(viewMemRefType.getElementType());
1870     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1871     if (!targetDescTy || !targetElementTy ||
1872         !LLVM::isCompatibleType(targetElementTy) ||
1873         !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();
1876 
1877     int64_t offset;
1878     SmallVector<int64_t, 4> strides;
1879     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1880     if (failed(successStrides))
1881       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1882     assert(offset == 0 && "expected offset to be 0");
1883 
1884     // Target memref must be contiguous in memory (innermost stride is 1), or
1885     // empty (special case when at least one of the memref dimensions is 0).
1886     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1887       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1888              failure();
1889 
1890     // Create the descriptor.
1891     MemRefDescriptor sourceMemRef(adaptor.getSource());
1892     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1893 
1894     // Field 1: Copy the allocated pointer, used for malloc/free.
1895     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1896     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
1897     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1898         loc,
1899         LLVM::LLVMPointerType::get(targetElementTy,
1900                                    srcMemRefType.getMemorySpaceAsInt()),
1901         allocatedPtr);
1902     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1903 
1904     // Field 2: Copy the actual aligned pointer to payload.
1905     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1906     alignedPtr = rewriter.create<LLVM::GEPOp>(
1907         loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
1908     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1909         loc,
1910         LLVM::LLVMPointerType::get(targetElementTy,
1911                                    srcMemRefType.getMemorySpaceAsInt()),
1912         alignedPtr);
1913     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1914 
1915     // Field 3: The offset in the resulting type must be 0. This is because of
1916     // the type change: an offset on srcType* may not be expressible as an
1917     // offset on dstType*.
1918     targetMemRef.setOffset(rewriter, loc,
1919                            createIndexConstant(rewriter, loc, offset));
1920 
1921     // Early exit for 0-D corner case.
1922     if (viewMemRefType.getRank() == 0)
1923       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1924 
1925     // Fields 4 and 5: Update sizes and strides.
1926     Value stride = nullptr, nextSize = nullptr;
1927     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1928       // Update size.
1929       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1930                            adaptor.getSizes(), i);
1931       targetMemRef.setSize(rewriter, loc, i, size);
1932       // Update stride.
1933       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1934       targetMemRef.setStride(rewriter, loc, i, stride);
1935       nextSize = size;
1936     }
1937 
1938     rewriter.replaceOp(viewOp, {targetMemRef});
1939     return success();
1940   }
1941 };
1942 
1943 //===----------------------------------------------------------------------===//
1944 // AtomicRMWOpLowering
1945 //===----------------------------------------------------------------------===//
1946 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
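/// For example (illustrative), `arith::AtomicRMWKind::addf` maps to
/// `LLVM::AtomicBinOp::fadd`, while kinds not handled below (e.g. mulf)
/// return llvm::None.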
1949 static Optional<LLVM::AtomicBinOp>
1950 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1951   switch (atomicOp.getKind()) {
1952   case arith::AtomicRMWKind::addf:
1953     return LLVM::AtomicBinOp::fadd;
1954   case arith::AtomicRMWKind::addi:
1955     return LLVM::AtomicBinOp::add;
1956   case arith::AtomicRMWKind::assign:
1957     return LLVM::AtomicBinOp::xchg;
1958   case arith::AtomicRMWKind::maxs:
1959     return LLVM::AtomicBinOp::max;
1960   case arith::AtomicRMWKind::maxu:
1961     return LLVM::AtomicBinOp::umax;
1962   case arith::AtomicRMWKind::mins:
1963     return LLVM::AtomicBinOp::min;
1964   case arith::AtomicRMWKind::minu:
1965     return LLVM::AtomicBinOp::umin;
1966   case arith::AtomicRMWKind::ori:
1967     return LLVM::AtomicBinOp::_or;
1968   case arith::AtomicRMWKind::andi:
1969     return LLVM::AtomicBinOp::_and;
1970   default:
1971     return llvm::None;
1972   }
1973   llvm_unreachable("Invalid AtomicRMWKind");
1974 }
1975 
1976 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1977   using Base::Base;
1978 
1979   LogicalResult
1980   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1981                   ConversionPatternRewriter &rewriter) const override {
1982     if (failed(match(atomicOp)))
1983       return failure();
1984     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1985     if (!maybeKind)
1986       return failure();
1987     auto resultType = adaptor.getValue().getType();
1988     auto memRefType = atomicOp.getMemRefType();
1989     auto dataPtr =
1990         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1991                              adaptor.getIndices(), rewriter);
1992     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1993         atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
1994         LLVM::AtomicOrdering::acq_rel);
1995     return success();
1996   }
1997 };
1998 
1999 } // namespace
2000 
2001 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
2002                                                   RewritePatternSet &patterns) {
2003   // clang-format off
2004   patterns.add<
2005       AllocaOpLowering,
2006       AllocaScopeOpLowering,
2007       AtomicRMWOpLowering,
2008       AssumeAlignmentOpLowering,
2009       DimOpLowering,
2010       GenericAtomicRMWOpLowering,
2011       GlobalMemrefOpLowering,
2012       GetGlobalMemrefOpLowering,
2013       LoadOpLowering,
2014       MemRefCastOpLowering,
2015       MemRefCopyOpLowering,
2016       MemRefReinterpretCastOpLowering,
2017       MemRefReshapeOpLowering,
2018       PrefetchOpLowering,
2019       RankOpLowering,
2020       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2021       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2022       StoreOpLowering,
2023       SubViewOpLowering,
2024       TransposeOpLowering,
2025       ViewOpLowering>(converter);
2026   // clang-format on
2027   auto allocLowering = converter.getOptions().allocLowering;
2028   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2029     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
2030   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2031     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
2032 }
2033 
2034 namespace {
2035 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2036   MemRefToLLVMPass() = default;
2037 
2038   void runOnOperation() override {
2039     Operation *op = getOperation();
2040     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2041     LowerToLLVMOptions options(&getContext(),
2042                                dataLayoutAnalysis.getAtOrAbove(op));
2043     options.allocLowering =
2044         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2045                          : LowerToLLVMOptions::AllocLowering::Malloc);
2046     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2047       options.overrideIndexBitwidth(indexBitwidth);
2048 
2049     LLVMTypeConverter typeConverter(&getContext(), options,
2050                                     &dataLayoutAnalysis);
2051     RewritePatternSet patterns(&getContext());
2052     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2053     LLVMConversionTarget target(getContext());
2054     target.addLegalOp<func::FuncOp>();
2055     if (failed(applyPartialConversion(op, target, std::move(patterns))))
2056       signalPassFailure();
2057   }
2058 };
2059 } // namespace
2060 
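/// Create the MemRef-to-LLVM lowering pass. Usage sketch (illustrative):
///   pm.addPass(createMemRefToLLVMPass());
/// The pass is also assumed to be exposed to mlir-opt as
/// `--convert-memref-to-llvm` in this revision.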
2061 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2062   return std::make_unique<MemRefToLLVMPass>();
2063 }
2064