1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
27   AllocOpLowering(LLVMTypeConverter &converter)
28       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
29                                 converter) {}
30 
31   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
32                                           Location loc, Value sizeBytes,
33                                           Operation *op) const override {
34     // Heap allocations.
35     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
36     MemRefType memRefType = allocOp.getType();
37 
38     Value alignment;
39     if (auto alignmentAttr = allocOp.alignment()) {
40       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
41     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // When no alignment is specified, we may still want to override
      // `malloc`'s behavior: `malloc` typically aligns at the size of the
      // largest scalar type on the target. For non-scalar element types, use
      // the natural alignment of the converted LLVM type as given by the
      // LLVM DataLayout.
46       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
47     }
48 
49     if (alignment) {
50       // Adjust the allocation size to consider alignment.
51       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
52     }
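    // As an illustrative sketch with hypothetical numbers: for sizeBytes = 100
    // and alignment = 16, we request 116 bytes from malloc, so that rounding
    // the returned pointer up to the next 16-byte boundary still leaves at
    // least 100 usable bytes.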
53 
54     // Allocate the underlying buffer and store a pointer to it in the MemRef
55     // descriptor.
56     Type elementPtrType = this->getElementPtrType(memRefType);
57     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
58         allocOp->getParentOfType<ModuleOp>(), getIndexType());
59     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
60                                   getVoidPtrType());
61     Value allocatedPtr =
62         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
63 
64     Value alignedPtr = allocatedPtr;
65     if (alignment) {
66       // Compute the aligned type pointer.
67       Value allocatedInt =
68           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
69       Value alignmentInt =
70           createAligned(rewriter, loc, allocatedInt, alignment);
71       alignedPtr =
72           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
73     }
74 
75     return std::make_tuple(allocatedPtr, alignedPtr);
76   }
77 };
78 
79 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
80   AlignedAllocOpLowering(LLVMTypeConverter &converter)
81       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
82                                 converter) {}
83 
84   /// Returns the memref's element size in bytes using the data layout active at
85   /// `op`.
86   // TODO: there are other places where this is used. Expose publicly?
87   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
88     const DataLayout *layout = &defaultLayout;
89     if (const DataLayoutAnalysis *analysis =
90             getTypeConverter()->getDataLayoutAnalysis()) {
91       layout = &analysis->getAbove(op);
92     }
93     Type elementType = memRefType.getElementType();
94     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
95       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
96                                                          *layout);
97     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
98       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
99           memRefElementType, *layout);
100     return layout->getTypeSize(elementType);
101   }
102 
  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
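  /// For example (hypothetical shape): for memref<4x?xf32> the element size is
  /// 4 bytes and the static extent 4 contributes another factor of 4, so the
  /// size is known to be a multiple of 16 regardless of the dynamic extent.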
105   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
106                               Operation *op) const {
107     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
108     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
109       if (ShapedType::isDynamic(type.getDimSize(i)))
110         continue;
111       sizeDivisor = sizeDivisor * type.getDimSize(i);
112     }
113     return sizeDivisor % factor == 0;
114   }
115 
116   /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
119   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
120     if (Optional<uint64_t> alignment = allocOp.alignment())
121       return *alignment;
122 
    // Whenever no alignment is set, we use an alignment consistent with the
    // element type; since the alignment has to be a power of two, we bump the
    // element size to the next power of two if it is not one already.
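    // For instance (hypothetical element sizes): a 12-byte element is bumped
    // to PowerOf2Ceil(12) = 16 and a 32-byte element yields an alignment of
    // 32, while anything smaller is clamped up to kMinAlignedAllocAlignment.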
126     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
127     return std::max(kMinAlignedAllocAlignment,
128                     llvm::PowerOf2Ceil(eltSizeBytes));
129   }
130 
131   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
132                                           Location loc, Value sizeBytes,
133                                           Operation *op) const override {
134     // Heap allocations.
135     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
136     MemRefType memRefType = allocOp.getType();
137     int64_t alignment = getAllocationAlignment(allocOp);
138     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
139 
140     // aligned_alloc requires size to be a multiple of alignment; we will pad
141     // the size to the next multiple if necessary.
142     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
143       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
144 
145     Type elementPtrType = this->getElementPtrType(memRefType);
146     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
147         allocOp->getParentOfType<ModuleOp>(), getIndexType());
148     auto results =
149         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
150                        getVoidPtrType());
151     Value allocatedPtr =
152         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
153 
154     return std::make_tuple(allocatedPtr, allocatedPtr);
155   }
156 
157   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
158   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
159 
160   /// Default layout to use in absence of the corresponding analysis.
161   DataLayout defaultLayout;
162 };
163 
// Out-of-line definition, required until C++17.
165 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
166 
167 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
168   AllocaOpLowering(LLVMTypeConverter &converter)
169       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
170                                 converter) {}
171 
  /// Allocates the underlying buffer with an llvm.alloca of the element type,
  /// using the alignment requested on the op, if any. For stack allocations
  /// the allocated and aligned pointers are the same value.
175   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
176                                           Location loc, Value sizeBytes,
177                                           Operation *op) const override {
178 
    // With alloca, one gets a pointer to the element type right away; no
    // separate aligned pointer is needed for stack allocations.
181     auto allocaOp = cast<memref::AllocaOp>(op);
182     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
183 
184     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
185         loc, elementPtrType, sizeBytes,
186         allocaOp.alignment() ? *allocaOp.alignment() : 0);
187 
188     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
189   }
190 };
191 
192 struct AllocaScopeOpLowering
193     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
194   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
195 
196   LogicalResult
197   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override {
199     OpBuilder::InsertionGuard guard(rewriter);
200     Location loc = allocaScopeOp.getLoc();
201 
202     // Split the current block before the AllocaScopeOp to create the inlining
203     // point.
204     auto *currentBlock = rewriter.getInsertionBlock();
205     auto *remainingOpsBlock =
206         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
207     Block *continueBlock;
208     if (allocaScopeOp.getNumResults() == 0) {
209       continueBlock = remainingOpsBlock;
210     } else {
211       continueBlock = rewriter.createBlock(
212           remainingOpsBlock, allocaScopeOp.getResultTypes(),
213           SmallVector<Location>(allocaScopeOp->getNumResults(),
214                                 allocaScopeOp.getLoc()));
215       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
216     }
217 
218     // Inline body region.
219     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
220     Block *afterBody = &allocaScopeOp.bodyRegion().back();
221     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
222 
223     // Save stack and then branch into the body of the region.
224     rewriter.setInsertionPointToEnd(currentBlock);
225     auto stackSaveOp =
226         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
227     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
228 
    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
231     rewriter.setInsertionPointToEnd(afterBody);
232     auto returnOp =
233         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
234     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
235         returnOp, returnOp.results(), continueBlock);
236 
237     // Insert stack restore before jumping out the body of the region.
238     rewriter.setInsertionPoint(branchOp);
239     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
240 
    // Replace the op with the values returned from the body region.
242     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
243 
244     return success();
245   }
246 };
247 
248 struct AssumeAlignmentOpLowering
249     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
250   using ConvertOpToLLVMPattern<
251       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
252 
253   LogicalResult
254   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
255                   ConversionPatternRewriter &rewriter) const override {
256     Value memref = adaptor.memref();
257     unsigned alignment = op.alignment();
258     auto loc = op.getLoc();
259 
260     MemRefDescriptor memRefDescriptor(memref);
261     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
262 
263     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
264     // the asserted memref.alignedPtr isn't used anywhere else, as the real
265     // users like load/store/views always re-extract memref.alignedPtr as they
266     // get lowered.
267     //
268     // This relies on LLVM's CSE optimization (potentially after SROA), since
269     // after CSE all memref.alignedPtr instances get de-duplicated into the same
270     // pointer SSA value.
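    // For example, with alignment = 16 the mask below is 15 and the emitted
    // check is (ptrtoint(alignedPtr) & 15) == 0, which is then fed to the
    // llvm.intr.assume intrinsic.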
271     auto intPtrType =
272         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
273     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
274     Value mask =
275         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
276     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
277     rewriter.create<LLVM::AssumeOp>(
278         loc, rewriter.create<LLVM::ICmpOp>(
279                  loc, LLVM::ICmpPredicate::eq,
280                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
281 
282     rewriter.eraseOp(op);
283     return success();
284   }
285 };
286 
287 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
288 // The memref descriptor being an SSA value, there is no need to clean it up
289 // in any way.
290 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
291   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
292 
293   explicit DeallocOpLowering(LLVMTypeConverter &converter)
294       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
295 
296   LogicalResult
297   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
298                   ConversionPatternRewriter &rewriter) const override {
299     // Insert the `free` declaration if it is not already present.
300     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
301     MemRefDescriptor memref(adaptor.memref());
302     Value casted = rewriter.create<LLVM::BitcastOp>(
303         op.getLoc(), getVoidPtrType(),
304         memref.allocatedPtr(rewriter, op.getLoc()));
305     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
306         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
307     return success();
308   }
309 };
310 
311 // A `dim` is converted to a constant for static sizes and to an access to the
312 // size stored in the memref descriptor for dynamic sizes.
313 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
314   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
315 
316   LogicalResult
317   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
318                   ConversionPatternRewriter &rewriter) const override {
319     Type operandType = dimOp.source().getType();
320     if (operandType.isa<UnrankedMemRefType>()) {
321       rewriter.replaceOp(
322           dimOp, {extractSizeOfUnrankedMemRef(
323                      operandType, dimOp, adaptor.getOperands(), rewriter)});
324 
325       return success();
326     }
327     if (operandType.isa<MemRefType>()) {
328       rewriter.replaceOp(
329           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
330                                             adaptor.getOperands(), rewriter)});
331       return success();
332     }
333     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
334   }
335 
336 private:
337   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
338                                     OpAdaptor adaptor,
339                                     ConversionPatternRewriter &rewriter) const {
340     Location loc = dimOp.getLoc();
341 
342     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
343     auto scalarMemRefType =
344         MemRefType::get({}, unrankedMemRefType.getElementType());
345     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
346 
347     // Extract pointer to the underlying ranked descriptor and bitcast it to a
348     // memref<element_type> descriptor pointer to minimize the number of GEP
349     // operations.
350     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
351     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
352     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
353         loc,
354         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
355                                    addressSpace),
356         underlyingRankedDesc);
357 
358     // Get pointer to offset field of memref<element_type> descriptor.
359     Type indexPtrTy = LLVM::LLVMPointerType::get(
360         getTypeConverter()->getIndexType(), addressSpace);
361     Value two = rewriter.create<LLVM::ConstantOp>(
362         loc, typeConverter->convertType(rewriter.getI32Type()),
363         rewriter.getI32IntegerAttr(2));
364     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
365         loc, indexPtrTy, scalarMemRefDescPtr,
366         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
367 
    // The size value that we have to extract can be obtained with a GEP using
    // `dimOp.index() + 1` as the index argument.
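    // Given the ranked descriptor layout (allocatedPtr, alignedPtr, offset,
    // sizes[rank], strides[rank]), the size of dimension `d` lives `d + 1`
    // index slots past the offset field, hence the `+ 1` below.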
370     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
371         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
372     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
373                                                  ValueRange({idxPlusOne}));
374     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
375   }
376 
377   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
378     if (Optional<int64_t> idx = dimOp.getConstantIndex())
379       return idx;
380 
381     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
382       return constantOp.getValue()
383           .cast<IntegerAttr>()
384           .getValue()
385           .getSExtValue();
386 
387     return llvm::None;
388   }
389 
390   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
391                                   OpAdaptor adaptor,
392                                   ConversionPatternRewriter &rewriter) const {
393     Location loc = dimOp.getLoc();
394 
    // Take advantage of a constant index, if available.
396     MemRefType memRefType = operandType.cast<MemRefType>();
397     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
398       int64_t i = index.getValue();
399       if (memRefType.isDynamicDim(i)) {
400         // extract dynamic size from the memref descriptor.
401         MemRefDescriptor descriptor(adaptor.source());
402         return descriptor.size(rewriter, loc, i);
403       }
404       // Use constant for static size.
405       int64_t dimSize = memRefType.getDimSize(i);
406       return createIndexConstant(rewriter, loc, dimSize);
407     }
408     Value index = adaptor.index();
409     int64_t rank = memRefType.getRank();
410     MemRefDescriptor memrefDescriptor(adaptor.source());
411     return memrefDescriptor.size(rewriter, loc, index, rank);
412   }
413 };
414 
415 /// Common base for load and store operations on MemRefs. Restricts the match
416 /// to supported MemRef types. Provides functionality to emit code accessing a
417 /// specific element of the underlying data buffer.
418 template <typename Derived>
419 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
420   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
421   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
422   using Base = LoadStoreOpLowering<Derived>;
423 
424   LogicalResult match(Derived op) const override {
425     MemRefType type = op.getMemRefType();
426     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
427   }
428 };
429 
/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
431 /// retried until it succeeds in atomically storing a new value into memory.
432 ///
433 ///      +---------------------------------+
434 ///      |   <code before the AtomicRMWOp> |
435 ///      |   <compute initial %loaded>     |
436 ///      |   br loop(%loaded)              |
437 ///      +---------------------------------+
438 ///             |
439 ///  -------|   |
440 ///  |      v   v
441 ///  |   +--------------------------------+
442 ///  |   | loop(%loaded):                 |
443 ///  |   |   <body contents>              |
444 ///  |   |   %pair = cmpxchg              |
445 ///  |   |   %ok = %pair[0]               |
446 ///  |   |   %new = %pair[1]              |
447 ///  |   |   cond_br %ok, end, loop(%new) |
448 ///  |   +--------------------------------+
449 ///  |          |        |
450 ///  |-----------        |
451 ///                      v
452 ///      +--------------------------------+
453 ///      | end:                           |
454 ///      |   <code after the AtomicRMWOp> |
455 ///      +--------------------------------+
456 ///
457 struct GenericAtomicRMWOpLowering
458     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
459   using Base::Base;
460 
461   LogicalResult
462   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
463                   ConversionPatternRewriter &rewriter) const override {
464     auto loc = atomicOp.getLoc();
465     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
466 
467     // Split the block into initial, loop, and ending parts.
468     auto *initBlock = rewriter.getInsertionBlock();
469     auto *loopBlock = rewriter.createBlock(
470         initBlock->getParent(), std::next(Region::iterator(initBlock)),
471         valueType, loc);
472     auto *endBlock = rewriter.createBlock(
473         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
474 
475     // Operations range to be moved to `endBlock`.
476     auto opsToMoveStart = atomicOp->getIterator();
477     auto opsToMoveEnd = initBlock->back().getIterator();
478 
479     // Compute the loaded value and branch to the loop block.
480     rewriter.setInsertionPointToEnd(initBlock);
481     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
482     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
483                                         adaptor.indices(), rewriter);
484     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
485     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
486 
487     // Prepare the body of the loop block.
488     rewriter.setInsertionPointToStart(loopBlock);
489 
490     // Clone the GenericAtomicRMWOp region and extract the result.
491     auto loopArgument = loopBlock->getArgument(0);
492     BlockAndValueMapping mapping;
493     mapping.map(atomicOp.getCurrentValue(), loopArgument);
494     Block &entryBlock = atomicOp.body().front();
495     for (auto &nestedOp : entryBlock.without_terminator()) {
496       Operation *clone = rewriter.clone(nestedOp, mapping);
497       mapping.map(nestedOp.getResults(), clone->getResults());
498     }
499     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
500 
501     // Prepare the epilog of the loop block.
502     // Append the cmpxchg op to the end of the loop block.
503     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
504     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
505     auto boolType = IntegerType::get(rewriter.getContext(), 1);
506     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
507                                                      {valueType, boolType});
508     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
509         loc, pairType, dataPtr, loopArgument, result, successOrdering,
510         failureOrdering);
511     // Extract the %new_loaded and %ok values from the pair.
512     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
513         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
514     Value ok = rewriter.create<LLVM::ExtractValueOp>(
515         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
516 
517     // Conditionally branch to the end or back to the loop depending on %ok.
518     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
519                                     loopBlock, newLoaded);
520 
521     rewriter.setInsertionPointToEnd(endBlock);
522     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
523                  std::next(opsToMoveEnd), rewriter);
524 
525     // The 'result' of the atomic_rmw op is the newly loaded value.
526     rewriter.replaceOp(atomicOp, {newLoaded});
527 
528     return success();
529   }
530 
531 private:
532   // Clones a segment of ops [start, end) and erases the original.
533   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
534                     Block::iterator start, Block::iterator end,
535                     ConversionPatternRewriter &rewriter) const {
536     BlockAndValueMapping mapping;
537     mapping.map(oldResult, newResult);
538     SmallVector<Operation *, 2> opsToErase;
539     for (auto it = start; it != end; ++it) {
540       rewriter.clone(*it, mapping);
541       opsToErase.push_back(&*it);
542     }
543     for (auto *it : opsToErase)
544       rewriter.eraseOp(it);
545   }
546 };
547 
548 /// Returns the LLVM type of the global variable given the memref type `type`.
549 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
550                                           LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
556   Type elementType = typeConverter.convertType(type.getElementType());
557   Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so we need to walk it
  // backwards.
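  // For example, memref<4x8xf32> becomes !llvm.array<4 x array<8 x f32>>.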
559   for (int64_t dim : llvm::reverse(type.getShape()))
560     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
561   return arrayTy;
562 }
563 
/// GlobalMemrefOp is lowered to an LLVM global variable.
565 struct GlobalMemrefOpLowering
566     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
567   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
568 
569   LogicalResult
570   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
571                   ConversionPatternRewriter &rewriter) const override {
572     MemRefType type = global.type();
573     if (!isConvertibleAndHasIdentityMaps(type))
574       return failure();
575 
576     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
577 
578     LLVM::Linkage linkage =
579         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
580 
581     Attribute initialValue = nullptr;
582     if (!global.isExternal() && !global.isUninitialized()) {
583       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
584       initialValue = elementsAttr;
585 
586       // For scalar memrefs, the global variable created is of the element type,
587       // so unpack the elements attribute to extract the value.
588       if (type.getRank() == 0)
589         initialValue = elementsAttr.getSplatValue<Attribute>();
590     }
591 
592     uint64_t alignment = global.alignment().getValueOr(0);
593 
594     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
595         global, arrayTy, global.constant(), linkage, global.sym_name(),
596         initialValue, alignment, type.getMemorySpaceAsInt());
597     if (!global.isExternal() && global.isUninitialized()) {
598       Block *blk = new Block();
599       newGlobal.getInitializerRegion().push_back(blk);
600       rewriter.setInsertionPointToStart(blk);
601       Value undef[] = {
602           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
603       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
604     }
605     return success();
606   }
607 };
608 
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` so that the MemRef descriptor construction can be
/// reused.
612 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
613   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
614       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
615                                 converter) {}
616 
  /// Buffer "allocation" for a memref.get_global op amounts to taking the
  /// address of the referenced global variable.
619   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
620                                           Location loc, Value sizeBytes,
621                                           Operation *op) const override {
622     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
623     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
624     unsigned memSpace = type.getMemorySpaceAsInt();
625 
626     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
627     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
628         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
629 
630     // Get the address of the first element in the array by creating a GEP with
631     // the address of the GV as the base, and (rank + 1) number of 0 indices.
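    // For a rank-2 global, for instance, the GEP indices are (0, 0, 0): the
    // leading zero steps through the pointer to the array value itself, and
    // the remaining zeros select the first element along each dimension.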
632     Type elementType = typeConverter->convertType(type.getElementType());
633     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
634 
635     SmallVector<Value> operands;
636     operands.insert(operands.end(), type.getRank() + 1,
637                     createIndexConstant(rewriter, loc, 0));
638     auto gep =
639         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
640 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
644     auto intPtrType = getIntPtrType(memSpace);
645     Value deadBeefConst =
646         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
647     auto deadBeefPtr =
648         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
649 
    // The aligned pointer is the address of the first element; the allocated
    // pointer holds the bad value set above, since we do not expect any
    // dealloc. We could potentially stash a nullptr there instead.
652     return std::make_tuple(deadBeefPtr, gep);
653   }
654 };
655 
656 // Load operation is lowered to obtaining a pointer to the indexed element
657 // and loading it.
658 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
659   using Base::Base;
660 
661   LogicalResult
662   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
663                   ConversionPatternRewriter &rewriter) const override {
664     auto type = loadOp.getMemRefType();
665 
666     Value dataPtr = getStridedElementPtr(
667         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
668     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
669     return success();
670   }
671 };
672 
673 // Store operation is lowered to obtaining a pointer to the indexed element,
674 // and storing the given value to it.
675 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
676   using Base::Base;
677 
678   LogicalResult
679   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
680                   ConversionPatternRewriter &rewriter) const override {
681     auto type = op.getMemRefType();
682 
683     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
684                                          adaptor.indices(), rewriter);
685     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
686     return success();
687   }
688 };
689 
690 // The prefetch operation is lowered in a way similar to the load operation
691 // except that the llvm.prefetch operation is used for replacement.
692 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
693   using Base::Base;
694 
695   LogicalResult
696   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
697                   ConversionPatternRewriter &rewriter) const override {
698     auto type = prefetchOp.getMemRefType();
699     auto loc = prefetchOp.getLoc();
700 
701     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
702                                          adaptor.indices(), rewriter);
703 
704     // Replace with llvm.prefetch.
705     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
706     auto isWrite = rewriter.create<LLVM::ConstantOp>(
707         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
708     auto localityHint = rewriter.create<LLVM::ConstantOp>(
709         loc, llvmI32Type,
710         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
711     auto isData = rewriter.create<LLVM::ConstantOp>(
712         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
713 
714     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
715                                                 localityHint, isData);
716     return success();
717   }
718 };
719 
720 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
721   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
722 
723   LogicalResult
724   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
725                   ConversionPatternRewriter &rewriter) const override {
726     Location loc = op.getLoc();
727     Type operandType = op.memref().getType();
728     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
729       UnrankedMemRefDescriptor desc(adaptor.memref());
730       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
731       return success();
732     }
733     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
734       rewriter.replaceOp(
735           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
736       return success();
737     }
738     return failure();
739   }
740 };
741 
742 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
743   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
744 
745   LogicalResult match(memref::CastOp memRefCastOp) const override {
746     Type srcType = memRefCastOp.getOperand().getType();
747     Type dstType = memRefCastOp.getType();
748 
    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed we can revisit.
754     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
755       return success(typeConverter->convertType(srcType) ==
756                      typeConverter->convertType(dstType));
757 
    // At least one of the operands has an unranked memref type.
759     assert(srcType.isa<UnrankedMemRefType>() ||
760            dstType.isa<UnrankedMemRefType>());
761 
    // Unranked-to-unranked casts are disallowed.
763     return !(srcType.isa<UnrankedMemRefType>() &&
764              dstType.isa<UnrankedMemRefType>())
765                ? success()
766                : failure();
767   }
768 
769   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
770                ConversionPatternRewriter &rewriter) const override {
771     auto srcType = memRefCastOp.getOperand().getType();
772     auto dstType = memRefCastOp.getType();
773     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
774     auto loc = memRefCastOp.getLoc();
775 
776     // For ranked/ranked case, just keep the original descriptor.
777     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
778       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
779 
780     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
781       // Casting ranked to unranked memref type
782       // Set the rank in the destination from the memref type
783       // Allocate space on the stack and copy the src memref descriptor
784       // Set the ptr in the destination to the stack space
785       auto srcMemRefType = srcType.cast<MemRefType>();
786       int64_t rank = srcMemRefType.getRank();
787       // ptr = AllocaOp sizeof(MemRefDescriptor)
788       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
789           loc, adaptor.source(), rewriter);
790       // voidptr = BitCastOp srcType* to void*
791       auto voidPtr =
792           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
793               .getResult();
794       // rank = ConstantOp srcRank
795       auto rankVal = rewriter.create<LLVM::ConstantOp>(
796           loc, typeConverter->convertType(rewriter.getIntegerType(64)),
797           rewriter.getI64IntegerAttr(rank));
798       // undef = UndefOp
799       UnrankedMemRefDescriptor memRefDesc =
800           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
801       // d1 = InsertValueOp undef, rank, 0
802       memRefDesc.setRank(rewriter, loc, rankVal);
803       // d2 = InsertValueOp d1, voidptr, 1
804       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
805       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
806 
807     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
808       // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the runtime type of the unranked operand,
      // the behavior is undefined.
811       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
812       // ptr = ExtractValueOp src, 1
813       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
814       // castPtr = BitCastOp i8* to structTy*
815       auto castPtr =
816           rewriter
817               .create<LLVM::BitcastOp>(
818                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
819               .getResult();
820       // struct = LoadOp castPtr
821       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
822       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
823     } else {
824       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
825     }
826   }
827 };
828 
829 /// Pattern to lower a `memref.copy` to llvm.
830 ///
831 /// For memrefs with identity layouts, the copy is lowered to the llvm
832 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
833 /// to the generic `MemrefCopyFn`.
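///
/// For example (illustrative only): copying a memref<4x4xf32> with identity
/// layout emits a single memcpy intrinsic of 4 * 4 * 4 = 64 bytes between the
/// aligned pointers of the two descriptors.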
834 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
835   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
836 
837   LogicalResult
838   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
839                           ConversionPatternRewriter &rewriter) const {
840     auto loc = op.getLoc();
841     auto srcType = op.source().getType().dyn_cast<MemRefType>();
842 
843     MemRefDescriptor srcDesc(adaptor.source());
844 
845     // Compute number of elements.
846     Value numElements = rewriter.create<LLVM::ConstantOp>(
847         loc, getIndexType(), rewriter.getIndexAttr(1));
848     for (int pos = 0; pos < srcType.getRank(); ++pos) {
849       auto size = srcDesc.size(rewriter, loc, pos);
850       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
851     }
852 
853     // Get element size.
854     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
855     // Compute total.
856     Value totalSize =
857         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
858 
859     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
860     MemRefDescriptor targetDesc(adaptor.target());
861     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
862     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
863         loc, typeConverter->convertType(rewriter.getI1Type()),
864         rewriter.getBoolAttr(false));
865     rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
866                                     isVolatile);
867     rewriter.eraseOp(op);
868 
869     return success();
870   }
871 
872   LogicalResult
873   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
874                              ConversionPatternRewriter &rewriter) const {
875     auto loc = op.getLoc();
876     auto srcType = op.source().getType().cast<BaseMemRefType>();
877     auto targetType = op.target().getType().cast<BaseMemRefType>();
878 
879     // First make sure we have an unranked memref descriptor representation.
880     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
881       auto rank = rewriter.create<LLVM::ConstantOp>(
882           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
883       auto *typeConverter = getTypeConverter();
884       auto ptr =
885           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
886       auto voidPtr =
887           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
888               .getResult();
889       auto unrankedType =
890           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
891       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
892                                             unrankedType,
893                                             ValueRange{rank, voidPtr});
894     };
895 
896     Value unrankedSource = srcType.hasRank()
897                                ? makeUnranked(adaptor.source(), srcType)
898                                : adaptor.source();
899     Value unrankedTarget = targetType.hasRank()
900                                ? makeUnranked(adaptor.target(), targetType)
901                                : adaptor.target();
902 
903     // Now promote the unranked descriptors to the stack.
904     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
905                                                  rewriter.getIndexAttr(1));
906     auto promote = [&](Value desc) {
907       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
908       auto allocated =
909           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
910       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
911       return allocated;
912     };
913 
914     auto sourcePtr = promote(unrankedSource);
915     auto targetPtr = promote(unrankedTarget);
916 
917     auto elemSize = rewriter.create<LLVM::ConstantOp>(
918         loc, getIndexType(),
919         rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
920     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
921         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
922     rewriter.create<LLVM::CallOp>(loc, copyFn,
923                                   ValueRange{elemSize, sourcePtr, targetPtr});
924     rewriter.eraseOp(op);
925 
926     return success();
927   }
928 
929   LogicalResult
930   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
931                   ConversionPatternRewriter &rewriter) const override {
932     auto srcType = op.source().getType().cast<BaseMemRefType>();
933     auto targetType = op.target().getType().cast<BaseMemRefType>();
934 
935     if (srcType.hasRank() &&
936         srcType.cast<MemRefType>().getLayout().isIdentity() &&
937         targetType.hasRank() &&
938         targetType.cast<MemRefType>().getLayout().isIdentity())
939       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
940 
941     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
942   }
943 };
944 
945 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
946 /// memref type. In unranked case, the fields are extracted from the underlying
947 /// ranked descriptor.
948 static void extractPointersAndOffset(Location loc,
949                                      ConversionPatternRewriter &rewriter,
950                                      LLVMTypeConverter &typeConverter,
951                                      Value originalOperand,
952                                      Value convertedOperand,
953                                      Value *allocatedPtr, Value *alignedPtr,
954                                      Value *offset = nullptr) {
955   Type operandType = originalOperand.getType();
956   if (operandType.isa<MemRefType>()) {
957     MemRefDescriptor desc(convertedOperand);
958     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
959     *alignedPtr = desc.alignedPtr(rewriter, loc);
960     if (offset != nullptr)
961       *offset = desc.offset(rewriter, loc);
962     return;
963   }
964 
965   unsigned memorySpace =
966       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
967   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
968   Type llvmElementType = typeConverter.convertType(elementType);
969   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
970       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
971 
972   // Extract pointer to the underlying ranked memref descriptor and cast it to
973   // ElemType**.
974   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
975   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
976 
977   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
978       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
979   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
980       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
981   if (offset != nullptr) {
982     *offset = UnrankedMemRefDescriptor::offset(
983         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
984   }
985 }
986 
987 struct MemRefReinterpretCastOpLowering
988     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
989   using ConvertOpToLLVMPattern<
990       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
991 
992   LogicalResult
993   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
994                   ConversionPatternRewriter &rewriter) const override {
995     Type srcType = castOp.source().getType();
996 
997     Value descriptor;
998     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
999                                                adaptor, &descriptor)))
1000       return failure();
1001     rewriter.replaceOp(castOp, {descriptor});
1002     return success();
1003   }
1004 
1005 private:
1006   LogicalResult convertSourceMemRefToDescriptor(
1007       ConversionPatternRewriter &rewriter, Type srcType,
1008       memref::ReinterpretCastOp castOp,
1009       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1010     MemRefType targetMemRefType =
1011         castOp.getResult().getType().cast<MemRefType>();
1012     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1013                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1014     if (!llvmTargetDescriptorTy)
1015       return failure();
1016 
1017     // Create descriptor.
1018     Location loc = castOp.getLoc();
1019     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1020 
1021     // Set allocated and aligned pointers.
1022     Value allocatedPtr, alignedPtr;
1023     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1024                              castOp.source(), adaptor.source(), &allocatedPtr,
1025                              &alignedPtr);
1026     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1027     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1028 
1029     // Set offset.
1030     if (castOp.isDynamicOffset(0))
1031       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1032     else
1033       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1034 
1035     // Set sizes and strides.
1036     unsigned dynSizeId = 0;
1037     unsigned dynStrideId = 0;
1038     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1039       if (castOp.isDynamicSize(i))
1040         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1041       else
1042         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1043 
1044       if (castOp.isDynamicStride(i))
1045         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1046       else
1047         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1048     }
1049     *descriptor = desc;
1050     return success();
1051   }
1052 };
1053 
1054 struct MemRefReshapeOpLowering
1055     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1056   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1057 
1058   LogicalResult
1059   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1060                   ConversionPatternRewriter &rewriter) const override {
1061     Type srcType = reshapeOp.source().getType();
1062 
1063     Value descriptor;
1064     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1065                                                adaptor, &descriptor)))
1066       return failure();
1067     rewriter.replaceOp(reshapeOp, {descriptor});
1068     return success();
1069   }
1070 
1071 private:
1072   LogicalResult
1073   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1074                                   Type srcType, memref::ReshapeOp reshapeOp,
1075                                   memref::ReshapeOp::Adaptor adaptor,
1076                                   Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref.reinterpret_cast`.
1079     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1080     if (shapeMemRefType.hasStaticShape())
1081       return failure();
1082 
    // The shape is a rank-1 memref with unknown length.
1084     Location loc = reshapeOp.getLoc();
1085     MemRefDescriptor shapeDesc(adaptor.shape());
1086     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1087 
1088     // Extract address space and element type.
1089     auto targetType =
1090         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1091     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1092     Type elementType = targetType.getElementType();
1093 
1094     // Create the unranked memref descriptor that holds the ranked one. The
1095     // inner descriptor is allocated on stack.
1096     auto targetDesc = UnrankedMemRefDescriptor::undef(
1097         rewriter, loc, typeConverter->convertType(targetType));
1098     targetDesc.setRank(rewriter, loc, resultRank);
1099     SmallVector<Value, 4> sizes;
1100     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1101                                            targetDesc, sizes);
1102     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1103         loc, getVoidPtrType(), sizes.front(), llvm::None);
1104     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1105 
1106     // Extract pointers and offset from the source memref.
1107     Value allocatedPtr, alignedPtr, offset;
1108     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1109                              reshapeOp.source(), adaptor.source(),
1110                              &allocatedPtr, &alignedPtr, &offset);
1111 
1112     // Set pointers and offset.
1113     Type llvmElementType = typeConverter->convertType(elementType);
1114     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1115         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1116     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1117                                               elementPtrPtrType, allocatedPtr);
1118     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1119                                             underlyingDescPtr,
1120                                             elementPtrPtrType, alignedPtr);
1121     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1122                                         underlyingDescPtr, elementPtrPtrType,
1123                                         offset);
1124 
1125     // Use the offset pointer as base for further addressing. Copy over the new
1126     // shape and compute strides. For this, we create a loop from rank-1 to 0.
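    // For a target shape (s0, s1, s2), for example, this produces the
    // row-major strides (s1 * s2, s2, 1), accumulating the product from the
    // innermost dimension outwards.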
1127     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1128         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1129         elementPtrPtrType);
1130     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1131         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1132     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1133     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1134     Value resultRankMinusOne =
1135         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1136 
1137     Block *initBlock = rewriter.getInsertionBlock();
1138     Type indexType = getTypeConverter()->getIndexType();
1139     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1140 
1141     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1142                                             {indexType, indexType}, {loc, loc});
1143 
1144     // Move the remaining initBlock ops to condBlock.
1145     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1146     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1147 
1148     rewriter.setInsertionPointToEnd(initBlock);
1149     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1150                                 condBlock);
1151     rewriter.setInsertionPointToStart(condBlock);
1152     Value indexArg = condBlock->getArgument(0);
1153     Value strideArg = condBlock->getArgument(1);
1154 
1155     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1156     Value pred = rewriter.create<LLVM::ICmpOp>(
1157         loc, IntegerType::get(rewriter.getContext(), 1),
1158         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1159 
1160     Block *bodyBlock =
1161         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1162     rewriter.setInsertionPointToStart(bodyBlock);
1163 
1164     // Copy size from shape to descriptor.
1165     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1166     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1167         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1168     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1169     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1170                                       targetSizesBase, indexArg, size);
1171 
1172     // Write stride value and compute next one.
1173     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1174                                         targetStridesBase, indexArg, strideArg);
1175     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1176 
1177     // Decrement loop counter and branch back.
1178     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1179     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1180                                 condBlock);
1181 
1182     Block *remainder =
1183         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1184 
1185     // Hook up the cond exit to the remainder.
1186     rewriter.setInsertionPointToEnd(condBlock);
1187     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1188                                     llvm::None);
1189 
1190     // Reset position to beginning of new remainder block.
1191     rewriter.setInsertionPointToStart(remainder);
1192 
1193     *descriptor = targetDesc;
1194     return success();
1195   }
1196 };
1197 
/// Helper function to convert a range of `OpFoldResult`s into a vector of
/// `Value`s, materializing `Attribute` entries as constants of the LLVM index
/// type and passing `Value` entries through unchanged.
1200 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1201                                       Type &llvmIndexType,
1202                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1203   return llvm::to_vector<4>(
1204       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1205         if (auto attr = value.dyn_cast<Attribute>())
1206           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1207         return value.get<Value>();
1208       }));
1209 }
1210 
1211 /// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse of
/// the `reassociation` maps.
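/// For example, the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.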
1214 static DenseMap<int64_t, int64_t>
1215 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1216   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1217   for (auto &en : enumerate(reassociation)) {
1218     for (auto dim : en.value())
1219       expandedDimToCollapsedDim[dim] = en.index();
1220   }
1221   return expandedDimToCollapsedDim;
1222 }
1223 
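/// Compute the size of the `outDimIndex`-th dimension of the expanded result.
/// Static sizes are returned as attributes; a dynamic size is obtained by
/// dividing the corresponding source dimension by the product of the other
/// (necessarily static) sizes in its reassociation group. As an illustration,
/// expanding memref<?xf32> into memref<?x4xf32> with reassociation [[0, 1]]
/// yields outDim0 = inDim0 / 4.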
1224 static OpFoldResult
1225 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1226                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1227                          MemRefDescriptor &inDesc,
1228                          ArrayRef<int64_t> inStaticShape,
1229                          ArrayRef<ReassociationIndices> reassocation,
1230                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1231   int64_t outDimSize = outStaticShape[outDimIndex];
1232   if (!ShapedType::isDynamic(outDimSize))
1233     return b.getIndexAttr(outDimSize);
1234 
  // Compute the product of the sizes of all other output dims in the same
  // reassociation group as the current dim.
1237   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1238   int64_t otherDimSizesMul = 1;
1239   for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
1241       continue;
1242     int64_t otherDimSize = outStaticShape[otherDimIndex];
1243     assert(!ShapedType::isDynamic(otherDimSize) &&
1244            "single dimension cannot be expanded into multiple dynamic "
1245            "dimensions");
1246     otherDimSizesMul *= otherDimSize;
1247   }
1248 
  // outDimSize = inDimSize / otherDimSizesMul
1250   int64_t inDimSize = inStaticShape[inDimIndex];
1251   Value inDimSizeDynamic =
1252       ShapedType::isDynamic(inDimSize)
1253           ? inDesc.size(b, loc, inDimIndex)
1254           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1255                                        b.getIndexAttr(inDimSize));
1256   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1257       loc, inDimSizeDynamic,
1258       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1259                                  b.getIndexAttr(otherDimSizesMul)));
1260   return outDimSizeDynamic;
1261 }
1262 
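/// Compute the size of the `outDimIndex`-th dimension of the collapsed result
/// as the product of the source dimensions in its reassociation group. As an
/// illustration, collapsing memref<?x4xf32> into memref<?xf32> with
/// reassociation [[0, 1]] yields outDim0 = inDim0 * 4.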
1263 static OpFoldResult getCollapsedOutputDimSize(
1264     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1265     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1266     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1267   if (!ShapedType::isDynamic(outDimSize))
1268     return b.getIndexAttr(outDimSize);
1269 
1270   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1271   Value outDimSizeDynamic = c1;
1272   for (auto inDimIndex : reassocation[outDimIndex]) {
1273     int64_t inDimSize = inStaticShape[inDimIndex];
1274     Value inDimSizeDynamic =
1275         ShapedType::isDynamic(inDimSize)
1276             ? inDesc.size(b, loc, inDimIndex)
1277             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1278                                          b.getIndexAttr(inDimSize));
1279     outDimSizeDynamic =
1280         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1281   }
1282   return outDimSizeDynamic;
1283 }
1284 
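/// Compute the shape of the collapsed result, one dimension at a time.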
1285 static SmallVector<OpFoldResult, 4>
1286 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1287                         ArrayRef<ReassociationIndices> reassocation,
1288                         ArrayRef<int64_t> inStaticShape,
1289                         MemRefDescriptor &inDesc,
1290                         ArrayRef<int64_t> outStaticShape) {
1291   return llvm::to_vector<4>(llvm::map_range(
1292       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1293         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1294                                          outStaticShape[outDimIndex],
1295                                          inStaticShape, inDesc, reassocation);
1296       }));
1297 }
1298 
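/// Compute the shape of the expanded result, one dimension at a time.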
1299 static SmallVector<OpFoldResult, 4>
1300 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1301                        ArrayRef<ReassociationIndices> reassocation,
1302                        ArrayRef<int64_t> inStaticShape,
1303                        MemRefDescriptor &inDesc,
1304                        ArrayRef<int64_t> outStaticShape) {
1305   DenseMap<int64_t, int64_t> outDimToInDimMap =
1306       getExpandedDimToCollapsedDimMap(reassocation);
1307   return llvm::to_vector<4>(llvm::map_range(
1308       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1309         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1310                                         outStaticShape, inDesc, inStaticShape,
1311                                         reassocation, outDimToInDimMap);
1312       }));
1313 }
1314 
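/// Compute the result shape as `Value`s, dispatching to the collapse or expand
/// helper depending on whether the result rank is smaller or larger than the
/// source rank.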
1315 static SmallVector<Value>
1316 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1317                       ArrayRef<ReassociationIndices> reassocation,
1318                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1319                       ArrayRef<int64_t> outStaticShape) {
1320   return outStaticShape.size() < inStaticShape.size()
1321              ? getAsValues(b, loc, llvmIndexType,
1322                            getCollapsedOutputShape(b, loc, llvmIndexType,
1323                                                    reassocation, inStaticShape,
1324                                                    inDesc, outStaticShape))
1325              : getAsValues(b, loc, llvmIndexType,
1326                            getExpandedOutputShape(b, loc, llvmIndexType,
1327                                                   reassocation, inStaticShape,
1328                                                   inDesc, outStaticShape));
1329 }
1330 
/// Conversion pattern lowering memref.expand_shape / memref.collapse_shape to
/// a new view descriptor of the proper rank. Sizes and strides of the result
/// are recomputed from the source descriptor; when the shapes are not fully
/// static, both source and result types must have an identity layout.
1334 template <typename ReshapeOp>
1335 class ReassociatingReshapeOpConversion
1336     : public ConvertOpToLLVMPattern<ReshapeOp> {
1337 public:
1338   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1339   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1340 
1341   LogicalResult
1342   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1343                   ConversionPatternRewriter &rewriter) const override {
1344     MemRefType dstType = reshapeOp.getResultType();
1345     MemRefType srcType = reshapeOp.getSrcType();
1346 
1347     // The condition on the layouts can be ignored when all shapes are static.
1348     if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
1349       if (!srcType.getLayout().isIdentity() ||
1350           !dstType.getLayout().isIdentity()) {
1351         return rewriter.notifyMatchFailure(
1352             reshapeOp, "only empty layout map is supported");
1353       }
1354     }
1355 
1356     int64_t offset;
1357     SmallVector<int64_t, 4> strides;
1358     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1359       return rewriter.notifyMatchFailure(
1360           reshapeOp, "failed to get stride and offset exprs");
1361     }
1362 
1363     MemRefDescriptor srcDesc(adaptor.src());
1364     Location loc = reshapeOp->getLoc();
1365     auto dstDesc = MemRefDescriptor::undef(
1366         rewriter, loc, this->typeConverter->convertType(dstType));
1367     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1368     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1369     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1370 
1371     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1372     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1373     Type llvmIndexType =
1374         this->typeConverter->convertType(rewriter.getIndexType());
1375     SmallVector<Value> dstShape = getDynamicOutputShape(
1376         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1377         srcStaticShape, srcDesc, dstStaticShape);
1378     for (auto &en : llvm::enumerate(dstShape))
1379       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1380 
1381     auto isStaticStride = [](int64_t stride) {
1382       return !ShapedType::isDynamicStrideOrOffset(stride);
1383     };
1384     if (llvm::all_of(strides, isStaticStride)) {
1385       for (auto &en : llvm::enumerate(strides))
1386         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1387     } else {
1388       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1389                                                    rewriter.getIndexAttr(1));
1390       Value stride = c1;
1391       for (auto dimIndex :
1392            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1393         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1394         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1395       }
1396     }
1397     rewriter.replaceOp(reshapeOp, {dstDesc});
1398     return success();
1399   }
1400 };
1401 
1402 /// Conversion pattern that transforms a subview op into:
1403 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1404 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1405 ///      and stride.
1406 /// The subview op is replaced by the descriptor.
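/// As a rough sketch of the arithmetic performed below (the names are
/// informal, and constant-foldable cases are emitted as constants directly):
///   offset    = src.offset + sum_i(subview.offset[i] * src.stride[i])
///   size[j]   = subview.size[i]                  (non-dropped dims i -> j)
///   stride[j] = subview.stride[i] * src.stride[i]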
1407 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1408   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1409 
1410   LogicalResult
1411   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1412                   ConversionPatternRewriter &rewriter) const override {
1413     auto loc = subViewOp.getLoc();
1414 
1415     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1416     auto sourceElementTy =
1417         typeConverter->convertType(sourceMemRefType.getElementType());
1418 
1419     auto viewMemRefType = subViewOp.getType();
1420     auto inferredType = memref::SubViewOp::inferResultType(
1421                             subViewOp.getSourceType(),
1422                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1423                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1424                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1425                             .cast<MemRefType>();
1426     auto targetElementTy =
1427         typeConverter->convertType(viewMemRefType.getElementType());
1428     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1429     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1430         !LLVM::isCompatibleType(sourceElementTy) ||
1431         !LLVM::isCompatibleType(targetElementTy) ||
1432         !LLVM::isCompatibleType(targetDescTy))
1433       return failure();
1434 
1435     // Extract the offset and strides from the type.
1436     int64_t offset;
1437     SmallVector<int64_t, 4> strides;
1438     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1439     if (failed(successStrides))
1440       return failure();
1441 
1442     // Create the descriptor.
1443     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1444       return failure();
1445     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1446     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1447 
1448     // Copy the buffer pointer from the old descriptor to the new one.
1449     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1450     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1451         loc,
1452         LLVM::LLVMPointerType::get(targetElementTy,
1453                                    viewMemRefType.getMemorySpaceAsInt()),
1454         extracted);
1455     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1456 
1457     // Copy the aligned pointer from the old descriptor to the new one.
1458     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1459     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1460         loc,
1461         LLVM::LLVMPointerType::get(targetElementTy,
1462                                    viewMemRefType.getMemorySpaceAsInt()),
1463         extracted);
1464     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1465 
1466     size_t inferredShapeRank = inferredType.getRank();
1467     size_t resultShapeRank = viewMemRefType.getRank();
1468 
1469     // Extract strides needed to compute offset.
1470     SmallVector<Value, 4> strideValues;
1471     strideValues.reserve(inferredShapeRank);
1472     for (unsigned i = 0; i < inferredShapeRank; ++i)
1473       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1474 
1475     // Offset.
1476     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1477     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1478       targetMemRef.setConstantOffset(rewriter, loc, offset);
1479     } else {
1480       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1481       // `inferredShapeRank` may be larger than the number of offset operands
1482       // because of trailing semantics. In this case, the offset is guaranteed
1483       // to be interpreted as 0 and we can just skip the extra dimensions.
1484       for (unsigned i = 0, e = std::min(inferredShapeRank,
1485                                         subViewOp.getMixedOffsets().size());
1486            i < e; ++i) {
1487         Value offset =
1488             // TODO: need OpFoldResult ODS adaptor to clean this up.
1489             subViewOp.isDynamicOffset(i)
1490                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1491                 : rewriter.create<LLVM::ConstantOp>(
1492                       loc, llvmIndexType,
1493                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1494         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1495         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1496       }
1497       targetMemRef.setOffset(rewriter, loc, baseOffset);
1498     }
1499 
1500     // Update sizes and strides.
1501     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1502     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1503     assert(mixedSizes.size() == mixedStrides.size() &&
1504            "expected sizes and strides of equal length");
1505     llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
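    // Iterate over the inferred dims in reverse; `j` tracks the corresponding
    // result dim and is only advanced for dims that are not dropped.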
1506     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1507          i >= 0 && j >= 0; --i) {
1508       if (unusedDims.contains(i))
1509         continue;
1510 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the size of the corresponding source dimension and the stride to
      // be 1.
1514       Value size, stride;
1515       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1516         // If the static size is available, use it directly. This is similar to
1517         // the folding of dim(constant-op) but removes the need for dim to be
1518         // aware of LLVM constants and for this pass to be aware of std
1519         // constants.
1520         int64_t staticSize =
1521             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1522         if (staticSize != ShapedType::kDynamicSize) {
1523           size = rewriter.create<LLVM::ConstantOp>(
1524               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1525         } else {
1526           Value pos = rewriter.create<LLVM::ConstantOp>(
1527               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1528           Value dim =
1529               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1530           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1531               loc, llvmIndexType, dim);
1532           size = cast.getResult(0);
1533         }
1534         stride = rewriter.create<LLVM::ConstantOp>(
1535             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1536       } else {
1537         // TODO: need OpFoldResult ODS adaptor to clean this up.
1538         size =
1539             subViewOp.isDynamicSize(i)
1540                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1541                 : rewriter.create<LLVM::ConstantOp>(
1542                       loc, llvmIndexType,
1543                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1544         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1545           stride = rewriter.create<LLVM::ConstantOp>(
1546               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1547         } else {
1548           stride =
1549               subViewOp.isDynamicStride(i)
1550                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1551                   : rewriter.create<LLVM::ConstantOp>(
1552                         loc, llvmIndexType,
1553                         rewriter.getI64IntegerAttr(
1554                             subViewOp.getStaticStride(i)));
1555           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1556         }
1557       }
1558       targetMemRef.setSize(rewriter, loc, j, size);
1559       targetMemRef.setStride(rewriter, loc, j, stride);
1560       j--;
1561     }
1562 
1563     rewriter.replaceOp(subViewOp, {targetMemRef});
1564     return success();
1565   }
1566 };
1567 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to copy the data ptr and offset, and to
///      permute the sizes and strides of the original view.
/// The transpose op is replaced by the descriptor.
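/// As an illustration (an assumed example, not taken from a test): with
/// permutation (d0, d1) -> (d1, d0), the result's size[1] and stride[1] are
/// copied from the source's size[0] and stride[0], and vice versa.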
1575 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1576 public:
1577   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1578 
1579   LogicalResult
1580   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1581                   ConversionPatternRewriter &rewriter) const override {
1582     auto loc = transposeOp.getLoc();
1583     MemRefDescriptor viewMemRef(adaptor.in());
1584 
1585     // No permutation, early exit.
1586     if (transposeOp.permutation().isIdentity())
1587       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1588 
1589     auto targetMemRef = MemRefDescriptor::undef(
1590         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1591 
1592     // Copy the base and aligned pointers from the old descriptor to the new
1593     // one.
1594     targetMemRef.setAllocatedPtr(rewriter, loc,
1595                                  viewMemRef.allocatedPtr(rewriter, loc));
1596     targetMemRef.setAlignedPtr(rewriter, loc,
1597                                viewMemRef.alignedPtr(rewriter, loc));
1598 
    // Copy the offset from the old descriptor to the new one.
1600     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1601 
1602     // Iterate over the dimensions and apply size/stride permutation.
1603     for (const auto &en :
1604          llvm::enumerate(transposeOp.permutation().getResults())) {
1605       int sourcePos = en.index();
1606       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1607       targetMemRef.setSize(rewriter, loc, targetPos,
1608                            viewMemRef.size(rewriter, loc, sourcePos));
1609       targetMemRef.setStride(rewriter, loc, targetPos,
1610                              viewMemRef.stride(rewriter, loc, sourcePos));
1611     }
1612 
1613     rewriter.replaceOp(transposeOp, {targetMemRef});
1614     return success();
1615   }
1616 };
1617 
/// Conversion pattern that transforms a memref.view op into:
1619 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1620 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1621 ///      and stride.
1622 /// The view op is replaced by the descriptor.
1623 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1624   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1625 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by looking up the corresponding
  // dynamic size operand (counting the dynamic dimensions that precede idx).
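  // As an illustration (assumed shapes): for shape [4, ?, 8, ?] and idx == 3,
  // the result is dynamicSizes[1], since one dynamic dim precedes index 3.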
1628   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1629                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1630                 unsigned idx) const {
1631     assert(idx < shape.size());
1632     if (!ShapedType::isDynamic(shape[idx]))
1633       return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in the range [0, idx).
1635     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1636       return ShapedType::isDynamic(v);
1637     });
1638     return dynamicSizes[nDynamic];
1639   }
1640 
1641   // Build and return the idx^th stride, either by returning the constant stride
1642   // or by computing the dynamic stride from the current `runningStride` and
1643   // `nextSize`. The caller should keep a running stride and update it with the
1644   // result returned by this function.
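  // For a fully dynamic row-major view this builds up strides of
  // 1, size[r-1], size[r-1] * size[r-2], ... from the innermost dimension out.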
1645   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1646                   ArrayRef<int64_t> strides, Value nextSize,
1647                   Value runningStride, unsigned idx) const {
1648     assert(idx < strides.size());
1649     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1650       return createIndexConstant(rewriter, loc, strides[idx]);
1651     if (nextSize)
1652       return runningStride
1653                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1654                  : nextSize;
1655     assert(!runningStride);
1656     return createIndexConstant(rewriter, loc, 1);
1657   }
1658 
1659   LogicalResult
1660   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1661                   ConversionPatternRewriter &rewriter) const override {
1662     auto loc = viewOp.getLoc();
1663 
1664     auto viewMemRefType = viewOp.getType();
1665     auto targetElementTy =
1666         typeConverter->convertType(viewMemRefType.getElementType());
1667     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1668     if (!targetDescTy || !targetElementTy ||
1669         !LLVM::isCompatibleType(targetElementTy) ||
1670         !LLVM::isCompatibleType(targetDescTy))
1671       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1672              failure();
1673 
1674     int64_t offset;
1675     SmallVector<int64_t, 4> strides;
1676     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1677     if (failed(successStrides))
1678       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1679     assert(offset == 0 && "expected offset to be 0");
1680 
1681     // Create the descriptor.
1682     MemRefDescriptor sourceMemRef(adaptor.source());
1683     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1684 
1685     // Field 1: Copy the allocated pointer, used for malloc/free.
1686     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1687     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1688     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1689         loc,
1690         LLVM::LLVMPointerType::get(targetElementTy,
1691                                    srcMemRefType.getMemorySpaceAsInt()),
1692         allocatedPtr);
1693     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1694 
1695     // Field 2: Copy the actual aligned pointer to payload.
1696     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1697     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1698                                               alignedPtr, adaptor.byte_shift());
1699     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1700         loc,
1701         LLVM::LLVMPointerType::get(targetElementTy,
1702                                    srcMemRefType.getMemorySpaceAsInt()),
1703         alignedPtr);
1704     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1705 
1706     // Field 3: The offset in the resulting type must be 0. This is because of
1707     // the type change: an offset on srcType* may not be expressible as an
1708     // offset on dstType*.
1709     targetMemRef.setOffset(rewriter, loc,
1710                            createIndexConstant(rewriter, loc, offset));
1711 
1712     // Early exit for 0-D corner case.
1713     if (viewMemRefType.getRank() == 0)
1714       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1715 
1716     // Fields 4 and 5: Update sizes and strides.
1717     if (strides.back() != 1)
1718       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1719              failure();
1720     Value stride = nullptr, nextSize = nullptr;
1721     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1722       // Update size.
1723       Value size =
1724           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1725       targetMemRef.setSize(rewriter, loc, i, size);
1726       // Update stride.
1727       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1728       targetMemRef.setStride(rewriter, loc, i, stride);
1729       nextSize = size;
1730     }
1731 
1732     rewriter.replaceOp(viewOp, {targetMemRef});
1733     return success();
1734   }
1735 };
1736 
1737 //===----------------------------------------------------------------------===//
1738 // AtomicRMWOpLowering
1739 //===----------------------------------------------------------------------===//
1740 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
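/// Kinds with no direct `llvm.atomicrmw` counterpart here (e.g. the
/// floating-point min/max variants) return `llvm::None` and are expected to be
/// handled by a compare-and-swap based lowering elsewhere.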
1743 static Optional<LLVM::AtomicBinOp>
1744 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1745   switch (atomicOp.kind()) {
1746   case arith::AtomicRMWKind::addf:
1747     return LLVM::AtomicBinOp::fadd;
1748   case arith::AtomicRMWKind::addi:
1749     return LLVM::AtomicBinOp::add;
1750   case arith::AtomicRMWKind::assign:
1751     return LLVM::AtomicBinOp::xchg;
1752   case arith::AtomicRMWKind::maxs:
1753     return LLVM::AtomicBinOp::max;
1754   case arith::AtomicRMWKind::maxu:
1755     return LLVM::AtomicBinOp::umax;
1756   case arith::AtomicRMWKind::mins:
1757     return LLVM::AtomicBinOp::min;
1758   case arith::AtomicRMWKind::minu:
1759     return LLVM::AtomicBinOp::umin;
1760   case arith::AtomicRMWKind::ori:
1761     return LLVM::AtomicBinOp::_or;
1762   case arith::AtomicRMWKind::andi:
1763     return LLVM::AtomicBinOp::_and;
1764   default:
1765     return llvm::None;
1766   }
1767   llvm_unreachable("Invalid AtomicRMWKind");
1768 }
1769 
1770 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1771   using Base::Base;
1772 
1773   LogicalResult
1774   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1775                   ConversionPatternRewriter &rewriter) const override {
1776     if (failed(match(atomicOp)))
1777       return failure();
1778     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1779     if (!maybeKind)
1780       return failure();
1781     auto resultType = adaptor.value().getType();
1782     auto memRefType = atomicOp.getMemRefType();
1783     auto dataPtr =
1784         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1785                              adaptor.indices(), rewriter);
1786     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1787         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1788         LLVM::AtomicOrdering::acq_rel);
1789     return success();
1790   }
1791 };
1792 
1793 } // namespace
1794 
1795 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1796                                                   RewritePatternSet &patterns) {
1797   // clang-format off
1798   patterns.add<
1799       AllocaOpLowering,
1800       AllocaScopeOpLowering,
1801       AtomicRMWOpLowering,
1802       AssumeAlignmentOpLowering,
1803       DimOpLowering,
1804       GenericAtomicRMWOpLowering,
1805       GlobalMemrefOpLowering,
1806       GetGlobalMemrefOpLowering,
1807       LoadOpLowering,
1808       MemRefCastOpLowering,
1809       MemRefCopyOpLowering,
1810       MemRefReinterpretCastOpLowering,
1811       MemRefReshapeOpLowering,
1812       PrefetchOpLowering,
1813       RankOpLowering,
1814       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1815       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1816       StoreOpLowering,
1817       SubViewOpLowering,
1818       TransposeOpLowering,
1819       ViewOpLowering>(converter);
1820   // clang-format on
1821   auto allocLowering = converter.getOptions().allocLowering;
1822   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1823     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1824   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1825     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1826 }
1827 
1828 namespace {
1829 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1830   MemRefToLLVMPass() = default;
1831 
1832   void runOnOperation() override {
1833     Operation *op = getOperation();
1834     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1835     LowerToLLVMOptions options(&getContext(),
1836                                dataLayoutAnalysis.getAtOrAbove(op));
1837     options.allocLowering =
1838         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1839                          : LowerToLLVMOptions::AllocLowering::Malloc);
1840     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1841       options.overrideIndexBitwidth(indexBitwidth);
1842 
1843     LLVMTypeConverter typeConverter(&getContext(), options,
1844                                     &dataLayoutAnalysis);
1845     RewritePatternSet patterns(&getContext());
1846     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1847     LLVMConversionTarget target(getContext());
1848     target.addLegalOp<FuncOp>();
1849     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1850       signalPassFailure();
1851   }
1852 };
1853 } // namespace
1854 
1855 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1856   return std::make_unique<MemRefToLLVMPass>();
1857 }
1858