1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 
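// Lowers memref.alloc to a heap allocation through `malloc`. As a rough,
// illustrative sketch (the exact size computation and descriptor population
// are handled by AllocLikeOpLLVMLowering):
//
//   %alloc = memref.alloc(%n) : memref<?xf32>
//
// becomes a call to `malloc` with the (possibly alignment-padded) byte size,
// a bitcast of the returned pointer to the element pointer type and, if an
// alignment was requested, pointer arithmetic to compute the aligned pointer.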
26 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
27   AllocOpLowering(LLVMTypeConverter &converter)
28       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
29                                 converter) {}
30 
31   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
32                                           Location loc, Value sizeBytes,
33                                           Operation *op) const override {
34     // Heap allocations.
35     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
36     MemRefType memRefType = allocOp.getType();
37 
38     Value alignment;
39     if (auto alignmentAttr = allocOp.alignment()) {
40       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
41     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns memory to the size of
      // the largest scalar type on the target hardware. For non-scalar element
      // types, use the natural alignment of the LLVM type given by the LLVM
      // DataLayout.
46       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
47     }
48 
49     if (alignment) {
50       // Adjust the allocation size to consider alignment.
51       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
52     }
53 
54     // Allocate the underlying buffer and store a pointer to it in the MemRef
55     // descriptor.
56     Type elementPtrType = this->getElementPtrType(memRefType);
57     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
58         allocOp->getParentOfType<ModuleOp>(), getIndexType());
59     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
60                                   getVoidPtrType());
61     Value allocatedPtr =
62         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
63 
64     Value alignedPtr = allocatedPtr;
65     if (alignment) {
66       // Compute the aligned type pointer.
67       Value allocatedInt =
68           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
69       Value alignmentInt =
70           createAligned(rewriter, loc, allocatedInt, alignment);
71       alignedPtr =
72           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
73     }
74 
75     return std::make_tuple(allocatedPtr, alignedPtr);
76   }
77 };
78 
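// Lowers memref.alloc through `aligned_alloc` instead of `malloc` when this
// pattern is selected. An illustrative sketch (exact IR may differ):
//
//   %alloc = memref.alloc() {alignment = 32 : i64} : memref<16xf32>
//
// becomes a call to `aligned_alloc(32, size)`, with the size padded up to a
// multiple of the alignment if necessary, followed by a bitcast to the element
// pointer type; the allocated and aligned pointers coincide.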
79 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
80   AlignedAllocOpLowering(LLVMTypeConverter &converter)
81       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
82                                 converter) {}
83 
84   /// Returns the memref's element size in bytes using the data layout active at
85   /// `op`.
86   // TODO: there are other places where this is used. Expose publicly?
87   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
88     const DataLayout *layout = &defaultLayout;
89     if (const DataLayoutAnalysis *analysis =
90             getTypeConverter()->getDataLayoutAnalysis()) {
91       layout = &analysis->getAbove(op);
92     }
93     Type elementType = memRefType.getElementType();
94     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
95       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
96                                                          *layout);
97     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
98       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
99           memRefElementType, *layout);
100     return layout->getTypeSize(elementType);
101   }
102 
103   /// Returns true if the memref size in bytes is known to be a multiple of
104   /// factor assuming the data layout active at `op`.
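  /// For example (illustrative): for memref<4x?xf32>, the element size is 4
  /// bytes and the static dimension contributes a factor of 4, so the size is
  /// known to be a multiple of 16 (and hence of 8), but not necessarily of 32.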
105   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
106                               Operation *op) const {
107     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
108     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
109       if (ShapedType::isDynamic(type.getDimSize(i)))
110         continue;
111       sizeDivisor = sizeDivisor * type.getDimSize(i);
112     }
113     return sizeDivisor % factor == 0;
114   }
115 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
119   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
120     if (Optional<uint64_t> alignment = allocOp.alignment())
121       return *alignment;
122 
    // Whenever we don't have an alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump it up to the next power of two if it isn't one already.
126     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
127     return std::max(kMinAlignedAllocAlignment,
128                     llvm::PowerOf2Ceil(eltSizeBytes));
129   }
130 
131   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
132                                           Location loc, Value sizeBytes,
133                                           Operation *op) const override {
134     // Heap allocations.
135     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
136     MemRefType memRefType = allocOp.getType();
137     int64_t alignment = getAllocationAlignment(allocOp);
138     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
139 
140     // aligned_alloc requires size to be a multiple of alignment; we will pad
141     // the size to the next multiple if necessary.
142     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
143       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
144 
145     Type elementPtrType = this->getElementPtrType(memRefType);
146     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
147         allocOp->getParentOfType<ModuleOp>(), getIndexType());
148     auto results =
149         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
150                        getVoidPtrType());
151     Value allocatedPtr =
152         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
153 
154     return std::make_tuple(allocatedPtr, allocatedPtr);
155   }
156 
157   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
158   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
159 
160   /// Default layout to use in absence of the corresponding analysis.
161   DataLayout defaultLayout;
162 };
163 
// Out-of-line definition, required until C++17.
165 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
166 
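// Lowers memref.alloca to an llvm.alloca carrying the requested alignment,
// e.g. (illustrative):
//
//   %0 = memref.alloca() : memref<4xf32>
//
// becomes an llvm.alloca that yields a pointer to the element type; since the
// alloca itself honors the alignment, the allocated and aligned pointers of
// the resulting descriptor are the same value.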
167 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
168   AllocaOpLowering(LLVMTypeConverter &converter)
169       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
170                                 converter) {}
171 
  /// Allocates the underlying buffer on the stack using `llvm.alloca`. Any
  /// requested alignment is handled by the alloca itself, so the allocated
  /// and aligned pointers returned are the same value.
175   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
176                                           Location loc, Value sizeBytes,
177                                           Operation *op) const override {
178 
    // With alloca, one gets a pointer to the element type right away: the
    // allocation happens directly on the stack.
181     auto allocaOp = cast<memref::AllocaOp>(op);
182     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
183 
184     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
185         loc, elementPtrType, sizeBytes,
186         allocaOp.alignment() ? *allocaOp.alignment() : 0);
187 
188     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
189   }
190 };
191 
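// Lowers memref.alloca_scope by inlining its body between a pair of
// llvm.intr.stacksave / llvm.intr.stackrestore operations, so stack
// allocations made inside the scope are released when control leaves it.
// Roughly (illustrative): the block containing the op is split, the body
// region is inlined in between, and the alloca_scope.return is rewritten into
// a stack restore followed by a branch to the continuation block.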
192 struct AllocaScopeOpLowering
193     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
194   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
195 
196   LogicalResult
197   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override {
199     OpBuilder::InsertionGuard guard(rewriter);
200     Location loc = allocaScopeOp.getLoc();
201 
202     // Split the current block before the AllocaScopeOp to create the inlining
203     // point.
204     auto *currentBlock = rewriter.getInsertionBlock();
205     auto *remainingOpsBlock =
206         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
207     Block *continueBlock;
208     if (allocaScopeOp.getNumResults() == 0) {
209       continueBlock = remainingOpsBlock;
210     } else {
211       continueBlock = rewriter.createBlock(
212           remainingOpsBlock, allocaScopeOp.getResultTypes(),
213           SmallVector<Location>(allocaScopeOp->getNumResults(),
214                                 allocaScopeOp.getLoc()));
215       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
216     }
217 
218     // Inline body region.
219     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
220     Block *afterBody = &allocaScopeOp.bodyRegion().back();
221     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
222 
223     // Save stack and then branch into the body of the region.
224     rewriter.setInsertionPointToEnd(currentBlock);
225     auto stackSaveOp =
226         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
227     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
228 
    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
231     rewriter.setInsertionPointToEnd(afterBody);
232     auto returnOp =
233         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
234     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
235         returnOp, returnOp.results(), continueBlock);
236 
    // Insert the stack restore before jumping out of the body of the region.
238     rewriter.setInsertionPoint(branchOp);
239     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
240 
    // Replace the op with the values returned from the body region.
242     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
243 
244     return success();
245   }
246 };
247 
248 struct AssumeAlignmentOpLowering
249     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
250   using ConvertOpToLLVMPattern<
251       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
252 
253   LogicalResult
254   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
255                   ConversionPatternRewriter &rewriter) const override {
256     Value memref = adaptor.memref();
257     unsigned alignment = op.alignment();
258     auto loc = op.getLoc();
259 
260     MemRefDescriptor memRefDescriptor(memref);
261     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
262 
263     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
264     // the asserted memref.alignedPtr isn't used anywhere else, as the real
265     // users like load/store/views always re-extract memref.alignedPtr as they
266     // get lowered.
267     //
268     // This relies on LLVM's CSE optimization (potentially after SROA), since
269     // after CSE all memref.alignedPtr instances get de-duplicated into the same
270     // pointer SSA value.
271     auto intPtrType =
272         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
273     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
274     Value mask =
275         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
276     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
277     rewriter.create<LLVM::AssumeOp>(
278         loc, rewriter.create<LLVM::ICmpOp>(
279                  loc, LLVM::ICmpPredicate::eq,
280                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
281 
282     rewriter.eraseOp(op);
283     return success();
284   }
285 };
286 
287 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
288 // The memref descriptor being an SSA value, there is no need to clean it up
289 // in any way.
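// An illustrative sketch:
//
//   memref.dealloc %buf : memref<?xf32>
//
// becomes a bitcast of the descriptor's allocated pointer to an i8 pointer
// followed by a call to `free`.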
290 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
291   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
292 
293   explicit DeallocOpLowering(LLVMTypeConverter &converter)
294       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
295 
296   LogicalResult
297   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
298                   ConversionPatternRewriter &rewriter) const override {
299     // Insert the `free` declaration if it is not already present.
300     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
301     MemRefDescriptor memref(adaptor.memref());
302     Value casted = rewriter.create<LLVM::BitcastOp>(
303         op.getLoc(), getVoidPtrType(),
304         memref.allocatedPtr(rewriter, op.getLoc()));
305     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
306         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
307     return success();
308   }
309 };
310 
311 // A `dim` is converted to a constant for static sizes and to an access to the
312 // size stored in the memref descriptor for dynamic sizes.
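// For example (illustrative), `memref.dim %m, %c1 : memref<4x?xf32>` reads the
// dynamic size out of the descriptor, whereas a constant index into a static
// dimension simply becomes an index constant. For unranked operands, the size
// is loaded from the underlying ranked descriptor.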
313 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
314   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
315 
316   LogicalResult
317   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
318                   ConversionPatternRewriter &rewriter) const override {
319     Type operandType = dimOp.source().getType();
320     if (operandType.isa<UnrankedMemRefType>()) {
321       rewriter.replaceOp(
322           dimOp, {extractSizeOfUnrankedMemRef(
323                      operandType, dimOp, adaptor.getOperands(), rewriter)});
324 
325       return success();
326     }
327     if (operandType.isa<MemRefType>()) {
328       rewriter.replaceOp(
329           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
330                                             adaptor.getOperands(), rewriter)});
331       return success();
332     }
333     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
334   }
335 
336 private:
337   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
338                                     OpAdaptor adaptor,
339                                     ConversionPatternRewriter &rewriter) const {
340     Location loc = dimOp.getLoc();
341 
342     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
343     auto scalarMemRefType =
344         MemRefType::get({}, unrankedMemRefType.getElementType());
345     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
346 
347     // Extract pointer to the underlying ranked descriptor and bitcast it to a
348     // memref<element_type> descriptor pointer to minimize the number of GEP
349     // operations.
350     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
351     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
352     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
353         loc,
354         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
355                                    addressSpace),
356         underlyingRankedDesc);
357 
358     // Get pointer to offset field of memref<element_type> descriptor.
359     Type indexPtrTy = LLVM::LLVMPointerType::get(
360         getTypeConverter()->getIndexType(), addressSpace);
361     Value two = rewriter.create<LLVM::ConstantOp>(
362         loc, typeConverter->convertType(rewriter.getI32Type()),
363         rewriter.getI32IntegerAttr(2));
364     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
365         loc, indexPtrTy, scalarMemRefDescPtr,
366         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
367 
    // The size value that we have to extract can be obtained using a GEPOp
    // with the index argument `dimOp.index() + 1`.
370     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
371         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
372     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
373                                                  ValueRange({idxPlusOne}));
374     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
375   }
376 
377   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
378     if (Optional<int64_t> idx = dimOp.getConstantIndex())
379       return idx;
380 
381     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
382       return constantOp.getValue()
383           .cast<IntegerAttr>()
384           .getValue()
385           .getSExtValue();
386 
387     return llvm::None;
388   }
389 
390   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
391                                   OpAdaptor adaptor,
392                                   ConversionPatternRewriter &rewriter) const {
393     Location loc = dimOp.getLoc();
394 
    // Take advantage of a constant index when possible.
396     MemRefType memRefType = operandType.cast<MemRefType>();
397     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
398       int64_t i = index.getValue();
399       if (memRefType.isDynamicDim(i)) {
400         // extract dynamic size from the memref descriptor.
401         MemRefDescriptor descriptor(adaptor.source());
402         return descriptor.size(rewriter, loc, i);
403       }
404       // Use constant for static size.
405       int64_t dimSize = memRefType.getDimSize(i);
406       return createIndexConstant(rewriter, loc, dimSize);
407     }
408     Value index = adaptor.index();
409     int64_t rank = memRefType.getRank();
410     MemRefDescriptor memrefDescriptor(adaptor.source());
411     return memrefDescriptor.size(rewriter, loc, index, rank);
412   }
413 };
414 
415 /// Returns the LLVM type of the global variable given the memref type `type`.
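/// For example (illustrative), memref<4x8xf32> maps to
/// !llvm.array<4 x array<8 x f32>>.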
416 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
417                                           LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.globals with an initial value,
  // we do not intend to flatten the ElementsAttribute when lowering to the
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
423   Type elementType = typeConverter.convertType(type.getElementType());
424   Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so we need to walk it
  // backwards.
426   for (int64_t dim : llvm::reverse(type.getShape()))
427     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
428   return arrayTy;
429 }
430 
/// GlobalMemrefOp is lowered to an LLVM global variable.
432 struct GlobalMemrefOpLowering
433     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
434   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
435 
436   LogicalResult
437   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
438                   ConversionPatternRewriter &rewriter) const override {
439     MemRefType type = global.type();
440     if (!isConvertibleAndHasIdentityMaps(type))
441       return failure();
442 
443     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
444 
445     LLVM::Linkage linkage =
446         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
447 
448     Attribute initialValue = nullptr;
449     if (!global.isExternal() && !global.isUninitialized()) {
450       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
451       initialValue = elementsAttr;
452 
453       // For scalar memrefs, the global variable created is of the element type,
454       // so unpack the elements attribute to extract the value.
455       if (type.getRank() == 0)
456         initialValue = elementsAttr.getSplatValue<Attribute>();
457     }
458 
459     uint64_t alignment = global.alignment().getValueOr(0);
460 
461     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
462         global, arrayTy, global.constant(), linkage, global.sym_name(),
463         initialValue, alignment, type.getMemorySpaceAsInt());
464     if (!global.isExternal() && global.isUninitialized()) {
465       Block *blk = new Block();
466       newGlobal.getInitializerRegion().push_back(blk);
467       rewriter.setInsertionPointToStart(blk);
468       Value undef[] = {
469           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
470       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
471     }
472     return success();
473   }
474 };
475 
/// GetGlobalMemrefOp is lowered into a memref descriptor with the pointer to
/// the first element stashed into it. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the memref descriptor construction.
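/// An illustrative sketch:
///
///   %0 = memref.get_global @gv : memref<2x3xf32>
///
/// becomes an llvm.mlir.addressof of @gv followed by a GEP with rank + 1 zero
/// indices yielding the pointer to the first element, which is then used as
/// the aligned pointer of the descriptor.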
479 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
480   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
481       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
482                                 converter) {}
483 
  /// Buffer "allocation" for the memref.get_global op amounts to taking the
  /// address of the referenced global variable.
486   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
487                                           Location loc, Value sizeBytes,
488                                           Operation *op) const override {
489     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
490     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
491     unsigned memSpace = type.getMemorySpaceAsInt();
492 
493     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
494     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
495         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
496 
    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and rank + 1 zero indices.
499     Type elementType = typeConverter->convertType(type.getElementType());
500     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
501 
502     SmallVector<Value> operands;
503     operands.insert(operands.end(), type.getRank() + 1,
504                     createIndexConstant(rewriter, loc, 0));
505     auto gep =
506         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
507 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debugging if that ever happens.
511     auto intPtrType = getIntPtrType(memSpace);
512     Value deadBeefConst =
513         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
514     auto deadBeefPtr =
515         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
516 
    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
519     return std::make_tuple(deadBeefPtr, gep);
520   }
521 };
522 
523 // Common base for load and store operations on MemRefs. Restricts the match
524 // to supported MemRef types. Provides functionality to emit code accessing a
525 // specific element of the underlying data buffer.
526 template <typename Derived>
527 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
528   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
529   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
530   using Base = LoadStoreOpLowering<Derived>;
531 
532   LogicalResult match(Derived op) const override {
533     MemRefType type = op.getMemRefType();
534     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
535   }
536 };
537 
538 // Load operation is lowered to obtaining a pointer to the indexed element
539 // and loading it.
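// An illustrative sketch:
//
//   %0 = memref.load %m[%i, %j] : memref<?x?xf32>
//
// becomes address arithmetic over the descriptor's aligned pointer, offset,
// and strides (a GEP), followed by an llvm.load of the element.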
540 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
541   using Base::Base;
542 
543   LogicalResult
544   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
545                   ConversionPatternRewriter &rewriter) const override {
546     auto type = loadOp.getMemRefType();
547 
548     Value dataPtr = getStridedElementPtr(
549         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
550     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
551     return success();
552   }
553 };
554 
555 // Store operation is lowered to obtaining a pointer to the indexed element,
556 // and storing the given value to it.
557 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
558   using Base::Base;
559 
560   LogicalResult
561   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
562                   ConversionPatternRewriter &rewriter) const override {
563     auto type = op.getMemRefType();
564 
565     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
566                                          adaptor.indices(), rewriter);
567     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
568     return success();
569   }
570 };
571 
572 // The prefetch operation is lowered in a way similar to the load operation
573 // except that the llvm.prefetch operation is used for replacement.
574 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
575   using Base::Base;
576 
577   LogicalResult
578   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
579                   ConversionPatternRewriter &rewriter) const override {
580     auto type = prefetchOp.getMemRefType();
581     auto loc = prefetchOp.getLoc();
582 
583     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
584                                          adaptor.indices(), rewriter);
585 
586     // Replace with llvm.prefetch.
587     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
588     auto isWrite = rewriter.create<LLVM::ConstantOp>(
589         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
590     auto localityHint = rewriter.create<LLVM::ConstantOp>(
591         loc, llvmI32Type,
592         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
593     auto isData = rewriter.create<LLVM::ConstantOp>(
594         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
595 
596     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
597                                                 localityHint, isData);
598     return success();
599   }
600 };
601 
602 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
603   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
604 
605   LogicalResult
606   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
607                   ConversionPatternRewriter &rewriter) const override {
608     Location loc = op.getLoc();
609     Type operandType = op.memref().getType();
610     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
611       UnrankedMemRefDescriptor desc(adaptor.memref());
612       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
613       return success();
614     }
615     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
616       rewriter.replaceOp(
617           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
618       return success();
619     }
620     return failure();
621   }
622 };
623 
624 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
625   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
626 
627   LogicalResult match(memref::CastOp memRefCastOp) const override {
628     Type srcType = memRefCastOp.getOperand().getType();
629     Type dstType = memRefCastOp.getType();
630 
    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed we can revisit this.
636     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
637       return success(typeConverter->convertType(srcType) ==
638                      typeConverter->convertType(dstType));
639 
    // At least one of the operands has an unranked memref type.
641     assert(srcType.isa<UnrankedMemRefType>() ||
642            dstType.isa<UnrankedMemRefType>());
643 
    // An unranked-to-unranked cast is disallowed.
645     return !(srcType.isa<UnrankedMemRefType>() &&
646              dstType.isa<UnrankedMemRefType>())
647                ? success()
648                : failure();
649   }
650 
651   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
652                ConversionPatternRewriter &rewriter) const override {
653     auto srcType = memRefCastOp.getOperand().getType();
654     auto dstType = memRefCastOp.getType();
655     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
656     auto loc = memRefCastOp.getLoc();
657 
658     // For ranked/ranked case, just keep the original descriptor.
659     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
660       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
661 
662     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
663       // Casting ranked to unranked memref type
664       // Set the rank in the destination from the memref type
665       // Allocate space on the stack and copy the src memref descriptor
666       // Set the ptr in the destination to the stack space
667       auto srcMemRefType = srcType.cast<MemRefType>();
668       int64_t rank = srcMemRefType.getRank();
669       // ptr = AllocaOp sizeof(MemRefDescriptor)
670       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
671           loc, adaptor.source(), rewriter);
672       // voidptr = BitCastOp srcType* to void*
673       auto voidPtr =
674           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
675               .getResult();
676       // rank = ConstantOp srcRank
677       auto rankVal = rewriter.create<LLVM::ConstantOp>(
678           loc, typeConverter->convertType(rewriter.getIntegerType(64)),
679           rewriter.getI64IntegerAttr(rank));
680       // undef = UndefOp
681       UnrankedMemRefDescriptor memRefDesc =
682           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
683       // d1 = InsertValueOp undef, rank, 0
684       memRefDesc.setRank(rewriter, loc, rankVal);
685       // d2 = InsertValueOp d1, voidptr, 1
686       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
687       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
688 
689     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from an unranked type to a ranked one.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the runtime type of the unranked
      // descriptor, the behavior is undefined.
693       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
694       // ptr = ExtractValueOp src, 1
695       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
696       // castPtr = BitCastOp i8* to structTy*
697       auto castPtr =
698           rewriter
699               .create<LLVM::BitcastOp>(
700                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
701               .getResult();
702       // struct = LoadOp castPtr
703       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
704       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
705     } else {
706       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
707     }
708   }
709 };
710 
711 /// Pattern to lower a `memref.copy` to llvm.
712 ///
713 /// For memrefs with identity layouts, the copy is lowered to the llvm
714 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
715 /// to the generic `MemrefCopyFn`.
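/// An illustrative sketch of the identity-layout case:
///
///   memref.copy %src, %dst : memref<?xf32> to memref<?xf32>
///
/// becomes a computation of the total size in bytes (the product of the sizes
/// times the element size) and a single llvm.intr.memcpy between the aligned
/// pointers of the two descriptors.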
716 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
717   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
718 
719   LogicalResult
720   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
721                           ConversionPatternRewriter &rewriter) const {
722     auto loc = op.getLoc();
723     auto srcType = op.source().getType().dyn_cast<MemRefType>();
724 
725     MemRefDescriptor srcDesc(adaptor.source());
726 
727     // Compute number of elements.
728     Value numElements = rewriter.create<LLVM::ConstantOp>(
729         loc, getIndexType(), rewriter.getIndexAttr(1));
730     for (int pos = 0; pos < srcType.getRank(); ++pos) {
731       auto size = srcDesc.size(rewriter, loc, pos);
732       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
733     }
734 
735     // Get element size.
736     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
737     // Compute total.
738     Value totalSize =
739         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
740 
741     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
742     MemRefDescriptor targetDesc(adaptor.target());
743     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
744     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
745         loc, typeConverter->convertType(rewriter.getI1Type()),
746         rewriter.getBoolAttr(false));
747     rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
748                                     isVolatile);
749     rewriter.eraseOp(op);
750 
751     return success();
752   }
753 
754   LogicalResult
755   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
756                              ConversionPatternRewriter &rewriter) const {
757     auto loc = op.getLoc();
758     auto srcType = op.source().getType().cast<BaseMemRefType>();
759     auto targetType = op.target().getType().cast<BaseMemRefType>();
760 
761     // First make sure we have an unranked memref descriptor representation.
762     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
763       auto rank = rewriter.create<LLVM::ConstantOp>(
764           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
765       auto *typeConverter = getTypeConverter();
766       auto ptr =
767           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
768       auto voidPtr =
769           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
770               .getResult();
771       auto unrankedType =
772           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
773       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
774                                             unrankedType,
775                                             ValueRange{rank, voidPtr});
776     };
777 
778     Value unrankedSource = srcType.hasRank()
779                                ? makeUnranked(adaptor.source(), srcType)
780                                : adaptor.source();
781     Value unrankedTarget = targetType.hasRank()
782                                ? makeUnranked(adaptor.target(), targetType)
783                                : adaptor.target();
784 
785     // Now promote the unranked descriptors to the stack.
786     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
787                                                  rewriter.getIndexAttr(1));
788     auto promote = [&](Value desc) {
789       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
790       auto allocated =
791           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
792       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
793       return allocated;
794     };
795 
796     auto sourcePtr = promote(unrankedSource);
797     auto targetPtr = promote(unrankedTarget);
798 
799     auto elemSize = rewriter.create<LLVM::ConstantOp>(
800         loc, getIndexType(),
801         rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
802     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
803         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
804     rewriter.create<LLVM::CallOp>(loc, copyFn,
805                                   ValueRange{elemSize, sourcePtr, targetPtr});
806     rewriter.eraseOp(op);
807 
808     return success();
809   }
810 
811   LogicalResult
812   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
813                   ConversionPatternRewriter &rewriter) const override {
814     auto srcType = op.source().getType().cast<BaseMemRefType>();
815     auto targetType = op.target().getType().cast<BaseMemRefType>();
816 
817     if (srcType.hasRank() &&
818         srcType.cast<MemRefType>().getLayout().isIdentity() &&
819         targetType.hasRank() &&
820         targetType.cast<MemRefType>().getLayout().isIdentity())
821       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
822 
823     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
824   }
825 };
826 
/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
830 static void extractPointersAndOffset(Location loc,
831                                      ConversionPatternRewriter &rewriter,
832                                      LLVMTypeConverter &typeConverter,
833                                      Value originalOperand,
834                                      Value convertedOperand,
835                                      Value *allocatedPtr, Value *alignedPtr,
836                                      Value *offset = nullptr) {
837   Type operandType = originalOperand.getType();
838   if (operandType.isa<MemRefType>()) {
839     MemRefDescriptor desc(convertedOperand);
840     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
841     *alignedPtr = desc.alignedPtr(rewriter, loc);
842     if (offset != nullptr)
843       *offset = desc.offset(rewriter, loc);
844     return;
845   }
846 
847   unsigned memorySpace =
848       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
849   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
850   Type llvmElementType = typeConverter.convertType(elementType);
851   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
852       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
853 
854   // Extract pointer to the underlying ranked memref descriptor and cast it to
855   // ElemType**.
856   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
857   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
858 
859   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
860       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
861   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
862       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
863   if (offset != nullptr) {
864     *offset = UnrankedMemRefDescriptor::offset(
865         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
866   }
867 }
868 
869 struct MemRefReinterpretCastOpLowering
870     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
871   using ConvertOpToLLVMPattern<
872       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
873 
874   LogicalResult
875   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
876                   ConversionPatternRewriter &rewriter) const override {
877     Type srcType = castOp.source().getType();
878 
879     Value descriptor;
880     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
881                                                adaptor, &descriptor)))
882       return failure();
883     rewriter.replaceOp(castOp, {descriptor});
884     return success();
885   }
886 
887 private:
888   LogicalResult convertSourceMemRefToDescriptor(
889       ConversionPatternRewriter &rewriter, Type srcType,
890       memref::ReinterpretCastOp castOp,
891       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
892     MemRefType targetMemRefType =
893         castOp.getResult().getType().cast<MemRefType>();
894     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
895                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
896     if (!llvmTargetDescriptorTy)
897       return failure();
898 
899     // Create descriptor.
900     Location loc = castOp.getLoc();
901     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
902 
903     // Set allocated and aligned pointers.
904     Value allocatedPtr, alignedPtr;
905     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
906                              castOp.source(), adaptor.source(), &allocatedPtr,
907                              &alignedPtr);
908     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
909     desc.setAlignedPtr(rewriter, loc, alignedPtr);
910 
911     // Set offset.
912     if (castOp.isDynamicOffset(0))
913       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
914     else
915       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
916 
917     // Set sizes and strides.
918     unsigned dynSizeId = 0;
919     unsigned dynStrideId = 0;
920     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
921       if (castOp.isDynamicSize(i))
922         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
923       else
924         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
925 
926       if (castOp.isDynamicStride(i))
927         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
928       else
929         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
930     }
931     *descriptor = desc;
932     return success();
933   }
934 };
935 
936 struct MemRefReshapeOpLowering
937     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
938   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
939 
940   LogicalResult
941   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
942                   ConversionPatternRewriter &rewriter) const override {
943     Type srcType = reshapeOp.source().getType();
944 
945     Value descriptor;
946     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
947                                                adaptor, &descriptor)))
948       return failure();
949     rewriter.replaceOp(reshapeOp, {descriptor});
950     return success();
951   }
952 
953 private:
954   LogicalResult
955   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
956                                   Type srcType, memref::ReshapeOp reshapeOp,
957                                   memref::ReshapeOp::Adaptor adaptor,
958                                   Value *descriptor) const {
    // The conversion for statically-known shape arguments is performed via
    // `memref.reinterpret_cast`.
961     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
962     if (shapeMemRefType.hasStaticShape())
963       return failure();
964 
    // The shape is a rank-1 memref with unknown length.
966     Location loc = reshapeOp.getLoc();
967     MemRefDescriptor shapeDesc(adaptor.shape());
968     Value resultRank = shapeDesc.size(rewriter, loc, 0);
969 
970     // Extract address space and element type.
971     auto targetType =
972         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
973     unsigned addressSpace = targetType.getMemorySpaceAsInt();
974     Type elementType = targetType.getElementType();
975 
    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
978     auto targetDesc = UnrankedMemRefDescriptor::undef(
979         rewriter, loc, typeConverter->convertType(targetType));
980     targetDesc.setRank(rewriter, loc, resultRank);
981     SmallVector<Value, 4> sizes;
982     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
983                                            targetDesc, sizes);
984     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
985         loc, getVoidPtrType(), sizes.front(), llvm::None);
986     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
987 
988     // Extract pointers and offset from the source memref.
989     Value allocatedPtr, alignedPtr, offset;
990     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
991                              reshapeOp.source(), adaptor.source(),
992                              &allocatedPtr, &alignedPtr, &offset);
993 
994     // Set pointers and offset.
995     Type llvmElementType = typeConverter->convertType(elementType);
996     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
997         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
998     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
999                                               elementPtrPtrType, allocatedPtr);
1000     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1001                                             underlyingDescPtr,
1002                                             elementPtrPtrType, alignedPtr);
1003     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1004                                         underlyingDescPtr, elementPtrPtrType,
1005                                         offset);
1006 
1007     // Use the offset pointer as base for further addressing. Copy over the new
1008     // shape and compute strides. For this, we create a loop from rank-1 to 0.
1009     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1010         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1011         elementPtrPtrType);
1012     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1013         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1014     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1015     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1016     Value resultRankMinusOne =
1017         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1018 
1019     Block *initBlock = rewriter.getInsertionBlock();
1020     Type indexType = getTypeConverter()->getIndexType();
1021     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1022 
1023     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1024                                             {indexType, indexType}, {loc, loc});
1025 
1026     // Move the remaining initBlock ops to condBlock.
1027     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1028     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1029 
1030     rewriter.setInsertionPointToEnd(initBlock);
1031     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1032                                 condBlock);
1033     rewriter.setInsertionPointToStart(condBlock);
1034     Value indexArg = condBlock->getArgument(0);
1035     Value strideArg = condBlock->getArgument(1);
1036 
1037     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1038     Value pred = rewriter.create<LLVM::ICmpOp>(
1039         loc, IntegerType::get(rewriter.getContext(), 1),
1040         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1041 
1042     Block *bodyBlock =
1043         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1044     rewriter.setInsertionPointToStart(bodyBlock);
1045 
1046     // Copy size from shape to descriptor.
1047     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1048     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1049         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1050     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1051     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1052                                       targetSizesBase, indexArg, size);
1053 
1054     // Write stride value and compute next one.
1055     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1056                                         targetStridesBase, indexArg, strideArg);
1057     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1058 
1059     // Decrement loop counter and branch back.
1060     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1061     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1062                                 condBlock);
1063 
1064     Block *remainder =
1065         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1066 
1067     // Hook up the cond exit to the remainder.
1068     rewriter.setInsertionPointToEnd(condBlock);
1069     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1070                                     llvm::None);
1071 
1072     // Reset position to beginning of new remainder block.
1073     rewriter.setInsertionPointToStart(remainder);
1074 
1075     *descriptor = targetDesc;
1076     return success();
1077   }
1078 };
1079 
1080 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1081 /// `Value`s.
1082 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1083                                       Type &llvmIndexType,
1084                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1085   return llvm::to_vector<4>(
1086       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1087         if (auto attr = value.dyn_cast<Attribute>())
1088           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1089         return value.get<Value>();
1090       }));
1091 }
1092 
/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the reassociation maps.
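/// For example (illustrative), the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.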
1096 static DenseMap<int64_t, int64_t>
1097 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1098   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1099   for (auto &en : enumerate(reassociation)) {
1100     for (auto dim : en.value())
1101       expandedDimToCollapsedDim[dim] = en.index();
1102   }
1103   return expandedDimToCollapsedDim;
1104 }
1105 
1106 static OpFoldResult
1107 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1108                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1109                          MemRefDescriptor &inDesc,
1110                          ArrayRef<int64_t> inStaticShape,
1111                          ArrayRef<ReassociationIndices> reassocation,
1112                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1113   int64_t outDimSize = outStaticShape[outDimIndex];
1114   if (!ShapedType::isDynamic(outDimSize))
1115     return b.getIndexAttr(outDimSize);
1116 
  // Compute the product of all the output dim sizes except the current one.
1119   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1120   int64_t otherDimSizesMul = 1;
1121   for (auto otherDimIndex : reassocation[inDimIndex]) {
1122     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1123       continue;
1124     int64_t otherDimSize = outStaticShape[otherDimIndex];
1125     assert(!ShapedType::isDynamic(otherDimSize) &&
1126            "single dimension cannot be expanded into multiple dynamic "
1127            "dimensions");
1128     otherDimSizesMul *= otherDimSize;
1129   }
1130 
1131   // outDimSize = inDimSize / otherOutDimSizesMul
1132   int64_t inDimSize = inStaticShape[inDimIndex];
1133   Value inDimSizeDynamic =
1134       ShapedType::isDynamic(inDimSize)
1135           ? inDesc.size(b, loc, inDimIndex)
1136           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1137                                        b.getIndexAttr(inDimSize));
1138   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1139       loc, inDimSizeDynamic,
1140       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1141                                  b.getIndexAttr(otherDimSizesMul)));
1142   return outDimSizeDynamic;
1143 }
1144 
1145 static OpFoldResult getCollapsedOutputDimSize(
1146     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1147     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1148     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1149   if (!ShapedType::isDynamic(outDimSize))
1150     return b.getIndexAttr(outDimSize);
1151 
1152   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1153   Value outDimSizeDynamic = c1;
1154   for (auto inDimIndex : reassocation[outDimIndex]) {
1155     int64_t inDimSize = inStaticShape[inDimIndex];
1156     Value inDimSizeDynamic =
1157         ShapedType::isDynamic(inDimSize)
1158             ? inDesc.size(b, loc, inDimIndex)
1159             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1160                                          b.getIndexAttr(inDimSize));
1161     outDimSizeDynamic =
1162         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1163   }
1164   return outDimSizeDynamic;
1165 }
1166 
1167 static SmallVector<OpFoldResult, 4>
1168 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1169                         ArrayRef<ReassociationIndices> reassocation,
1170                         ArrayRef<int64_t> inStaticShape,
1171                         MemRefDescriptor &inDesc,
1172                         ArrayRef<int64_t> outStaticShape) {
1173   return llvm::to_vector<4>(llvm::map_range(
1174       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1175         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1176                                          outStaticShape[outDimIndex],
1177                                          inStaticShape, inDesc, reassocation);
1178       }));
1179 }
1180 
1181 static SmallVector<OpFoldResult, 4>
1182 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1183                        ArrayRef<ReassociationIndices> reassocation,
1184                        ArrayRef<int64_t> inStaticShape,
1185                        MemRefDescriptor &inDesc,
1186                        ArrayRef<int64_t> outStaticShape) {
1187   DenseMap<int64_t, int64_t> outDimToInDimMap =
1188       getExpandedDimToCollapsedDimMap(reassocation);
1189   return llvm::to_vector<4>(llvm::map_range(
1190       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1191         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1192                                         outStaticShape, inDesc, inStaticShape,
1193                                         reassocation, outDimToInDimMap);
1194       }));
1195 }
1196 
1197 static SmallVector<Value>
1198 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1199                       ArrayRef<ReassociationIndices> reassocation,
1200                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1201                       ArrayRef<int64_t> outStaticShape) {
1202   return outStaticShape.size() < inStaticShape.size()
1203              ? getAsValues(b, loc, llvmIndexType,
1204                            getCollapsedOutputShape(b, loc, llvmIndexType,
1205                                                    reassocation, inStaticShape,
1206                                                    inDesc, outStaticShape))
1207              : getAsValues(b, loc, llvmIndexType,
1208                            getExpandedOutputShape(b, loc, llvmIndexType,
1209                                                   reassocation, inStaticShape,
1210                                                   inDesc, outStaticShape));
1211 }
1212 
1213 // Lowers memref.expand_shape/memref.collapse_shape by creating a new view
1214 // descriptor of the proper rank. When either the source or the result shape
1215 // is dynamic, only identity layout maps are supported.
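//
// Illustrative example (a sketch of the expected input IR, not copied from a
// test):
//   %r = memref.collapse_shape %m [[0, 1]]
//       : memref<4x5xf32> into memref<20xf32>
// is rewritten into a descriptor that reuses %m's allocated/aligned pointers
// and offset and fills in the collapsed size and strides.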
1216 template <typename ReshapeOp>
1217 class ReassociatingReshapeOpConversion
1218     : public ConvertOpToLLVMPattern<ReshapeOp> {
1219 public:
1220   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1221   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1222 
1223   LogicalResult
1224   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1225                   ConversionPatternRewriter &rewriter) const override {
1226     MemRefType dstType = reshapeOp.getResultType();
1227     MemRefType srcType = reshapeOp.getSrcType();
1228 
1229     // The condition on the layouts can be ignored when all shapes are static.
1230     if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
1231       if (!srcType.getLayout().isIdentity() ||
1232           !dstType.getLayout().isIdentity()) {
1233         return rewriter.notifyMatchFailure(
1234             reshapeOp, "only identity layout map is supported");
1235       }
1236     }
1237 
1238     int64_t offset;
1239     SmallVector<int64_t, 4> strides;
1240     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1241       return rewriter.notifyMatchFailure(
1242           reshapeOp, "failed to get stride and offset exprs");
1243     }
1244 
1245     MemRefDescriptor srcDesc(adaptor.src());
1246     Location loc = reshapeOp->getLoc();
1247     auto dstDesc = MemRefDescriptor::undef(
1248         rewriter, loc, this->typeConverter->convertType(dstType));
1249     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1250     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1251     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1252 
1253     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1254     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1255     Type llvmIndexType =
1256         this->typeConverter->convertType(rewriter.getIndexType());
1257     SmallVector<Value> dstShape = getDynamicOutputShape(
1258         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1259         srcStaticShape, srcDesc, dstStaticShape);
1260     for (auto &en : llvm::enumerate(dstShape))
1261       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1262 
1263     auto isStaticStride = [](int64_t stride) {
1264       return !ShapedType::isDynamicStrideOrOffset(stride);
1265     };
1266     if (llvm::all_of(strides, isStaticStride)) {
1267       for (auto &en : llvm::enumerate(strides))
1268         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1269     } else {
1270       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1271                                                    rewriter.getIndexAttr(1));
1272       Value stride = c1;
1273       for (auto dimIndex :
1274            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1275         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1276         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1277       }
1278     }
1279     rewriter.replaceOp(reshapeOp, {dstDesc});
1280     return success();
1281   }
1282 };
1283 
1284 /// Conversion pattern that transforms a subview op into:
1285 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1286 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1287 ///      and stride.
1288 /// The subview op is replaced by the descriptor.
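///
/// Illustrative example (a sketch; the result layout shown is what the strided
/// offset arithmetic below computes, i.e. offset = 2 * 8 + 0 * 1 = 16):
///   %1 = memref.subview %0[2, 0][4, 4][1, 1]
///       : memref<8x8xf32>
///       to memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 8 + d1 + 16)>>
/// lowers to a descriptor that reuses %0's pointers and fills in the offset,
/// sizes, and strides of the result type.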
1289 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1290   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1291 
1292   LogicalResult
1293   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1294                   ConversionPatternRewriter &rewriter) const override {
1295     auto loc = subViewOp.getLoc();
1296 
1297     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1298     auto sourceElementTy =
1299         typeConverter->convertType(sourceMemRefType.getElementType());
1300 
1301     auto viewMemRefType = subViewOp.getType();
1302     auto inferredType = memref::SubViewOp::inferResultType(
1303                             subViewOp.getSourceType(),
1304                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1305                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1306                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1307                             .cast<MemRefType>();
1308     auto targetElementTy =
1309         typeConverter->convertType(viewMemRefType.getElementType());
1310     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1311     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1312         !LLVM::isCompatibleType(sourceElementTy) ||
1313         !LLVM::isCompatibleType(targetElementTy) ||
1314         !LLVM::isCompatibleType(targetDescTy))
1315       return failure();
1316 
1317     // Extract the offset and strides from the type.
1318     int64_t offset;
1319     SmallVector<int64_t, 4> strides;
1320     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1321     if (failed(successStrides))
1322       return failure();
1323 
1324     // Create the descriptor.
1325     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1326       return failure();
1327     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1328     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1329 
1330     // Copy the buffer pointer from the old descriptor to the new one.
1331     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1332     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1333         loc,
1334         LLVM::LLVMPointerType::get(targetElementTy,
1335                                    viewMemRefType.getMemorySpaceAsInt()),
1336         extracted);
1337     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1338 
1339     // Copy the aligned pointer from the old descriptor to the new one.
1340     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1341     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1342         loc,
1343         LLVM::LLVMPointerType::get(targetElementTy,
1344                                    viewMemRefType.getMemorySpaceAsInt()),
1345         extracted);
1346     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1347 
1348     size_t inferredShapeRank = inferredType.getRank();
1349     size_t resultShapeRank = viewMemRefType.getRank();
1350 
1351     // Extract strides needed to compute offset.
1352     SmallVector<Value, 4> strideValues;
1353     strideValues.reserve(inferredShapeRank);
1354     for (unsigned i = 0; i < inferredShapeRank; ++i)
1355       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1356 
1357     // Offset.
1358     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1359     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1360       targetMemRef.setConstantOffset(rewriter, loc, offset);
1361     } else {
1362       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1363       // `inferredShapeRank` may be larger than the number of offset operands
1364       // because of trailing semantics. In this case, the offset is guaranteed
1365       // to be interpreted as 0 and we can just skip the extra dimensions.
1366       for (unsigned i = 0, e = std::min(inferredShapeRank,
1367                                         subViewOp.getMixedOffsets().size());
1368            i < e; ++i) {
1369         Value offset =
1370             // TODO: need OpFoldResult ODS adaptor to clean this up.
1371             subViewOp.isDynamicOffset(i)
1372                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1373                 : rewriter.create<LLVM::ConstantOp>(
1374                       loc, llvmIndexType,
1375                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1376         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1377         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1378       }
1379       targetMemRef.setOffset(rewriter, loc, baseOffset);
1380     }
1381 
1382     // Update sizes and strides.
1383     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1384     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1385     assert(mixedSizes.size() == mixedStrides.size() &&
1386            "expected sizes and strides of equal length");
1387     llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
1388     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1389          i >= 0 && j >= 0; --i) {
1390       if (unusedDims.contains(i))
1391         continue;
1392 
1393       // `i` may exceed the bounds of subViewOp.getMixedSizes() because of
1394       // trailing semantics. In that case, the size falls back to the size of
1395       // the corresponding source dimension and the stride to 1.
1396       Value size, stride;
1397       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1398         // If the static size is available, use it directly. This is similar to
1399         // the folding of dim(constant-op) but removes the need for dim to be
1400         // aware of LLVM constants and for this pass to be aware of std
1401         // constants.
1402         int64_t staticSize =
1403             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1404         if (staticSize != ShapedType::kDynamicSize) {
1405           size = rewriter.create<LLVM::ConstantOp>(
1406               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1407         } else {
1408           Value pos = rewriter.create<LLVM::ConstantOp>(
1409               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1410           Value dim =
1411               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1412           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1413               loc, llvmIndexType, dim);
1414           size = cast.getResult(0);
1415         }
1416         stride = rewriter.create<LLVM::ConstantOp>(
1417             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1418       } else {
1419         // TODO: need OpFoldResult ODS adaptor to clean this up.
1420         size =
1421             subViewOp.isDynamicSize(i)
1422                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1423                 : rewriter.create<LLVM::ConstantOp>(
1424                       loc, llvmIndexType,
1425                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1426         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1427           stride = rewriter.create<LLVM::ConstantOp>(
1428               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1429         } else {
1430           stride =
1431               subViewOp.isDynamicStride(i)
1432                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1433                   : rewriter.create<LLVM::ConstantOp>(
1434                         loc, llvmIndexType,
1435                         rewriter.getI64IntegerAttr(
1436                             subViewOp.getStaticStride(i)));
1437           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1438         }
1439       }
1440       targetMemRef.setSize(rewriter, loc, j, size);
1441       targetMemRef.setStride(rewriter, loc, j, stride);
1442       j--;
1443     }
1444 
1445     rewriter.replaceOp(subViewOp, {targetMemRef});
1446     return success();
1447   }
1448 };
1449 
1450 /// Conversion pattern that transforms a transpose op into:
1451 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
1452 ///   2. Updates to the descriptor that copy the allocated/aligned pointers
1453 ///      and the offset from the source descriptor.
1454 ///   3. Updates that write the source sizes and strides into the descriptor
1455 ///      in the order given by the permutation map.
1456 /// The transpose op is replaced by the new descriptor.
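///
/// Illustrative example (a sketch; #transposed stands in for the actual
/// strided result layout):
///   %t = memref.transpose %m (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, #transposed>
/// produces a descriptor with the same pointers and offset as %m and with the
/// sizes and strides of dims 0 and 1 swapped.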
1457 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1458 public:
1459   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1460 
1461   LogicalResult
1462   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1463                   ConversionPatternRewriter &rewriter) const override {
1464     auto loc = transposeOp.getLoc();
1465     MemRefDescriptor viewMemRef(adaptor.in());
1466 
1467     // No permutation, early exit.
1468     if (transposeOp.permutation().isIdentity())
1469       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1470 
1471     auto targetMemRef = MemRefDescriptor::undef(
1472         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1473 
1474     // Copy the base and aligned pointers from the old descriptor to the new
1475     // one.
1476     targetMemRef.setAllocatedPtr(rewriter, loc,
1477                                  viewMemRef.allocatedPtr(rewriter, loc));
1478     targetMemRef.setAlignedPtr(rewriter, loc,
1479                                viewMemRef.alignedPtr(rewriter, loc));
1480 
1481     // Copy the offset pointer from the old descriptor to the new one.
1482     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1483 
1484     // Iterate over the dimensions and apply size/stride permutation.
1485     for (const auto &en :
1486          llvm::enumerate(transposeOp.permutation().getResults())) {
1487       int sourcePos = en.index();
1488       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1489       targetMemRef.setSize(rewriter, loc, targetPos,
1490                            viewMemRef.size(rewriter, loc, sourcePos));
1491       targetMemRef.setStride(rewriter, loc, targetPos,
1492                              viewMemRef.stride(rewriter, loc, sourcePos));
1493     }
1494 
1495     rewriter.replaceOp(transposeOp, {targetMemRef});
1496     return success();
1497   }
1498 };
1499 
1500 /// Conversion pattern that transforms a memref.view op into:
1501 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1502 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1503 ///      and stride.
1504 /// The view op is replaced by the descriptor.
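///
/// Illustrative example (a sketch):
///   %v = memref.view %buf[%offset][%size]
///       : memref<2048xi8> to memref<?x128xf32>
/// lowers to a descriptor whose aligned pointer is advanced by %offset bytes
/// via llvm.getelementptr, whose offset field is 0, and whose strides are
/// recomputed from the sizes with the innermost stride equal to 1.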
1505 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1506   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1507 
1508   // Build and return the value for the idx^th shape dimension, either as a
1509   // constant or by selecting the dynamic size operand that corresponds to it.
1510   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1511                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1512                 unsigned idx) const {
1513     assert(idx < shape.size());
1514     if (!ShapedType::isDynamic(shape[idx]))
1515       return createIndexConstant(rewriter, loc, shape[idx]);
1516     // Count the number of dynamic dims in range [0, idx).
1517     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1518       return ShapedType::isDynamic(v);
1519     });
1520     return dynamicSizes[nDynamic];
1521   }
1522 
1523   // Build and return the idx^th stride, either by returning the constant stride
1524   // or by computing the dynamic stride from the current `runningStride` and
1525   // `nextSize`. The caller should keep a running stride and update it with the
1526   // result returned by this function.
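  //
  // Worked example (illustrative): for a contiguous memref<4x?xf32> the
  // innermost stride is the constant 1; the outer stride entry is dynamic, so
  // the next call returns runningStride(=1) * nextSize(=size of dim 1) as an
  // llvm.mul.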
1527   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1528                   ArrayRef<int64_t> strides, Value nextSize,
1529                   Value runningStride, unsigned idx) const {
1530     assert(idx < strides.size());
1531     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1532       return createIndexConstant(rewriter, loc, strides[idx]);
1533     if (nextSize)
1534       return runningStride
1535                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1536                  : nextSize;
1537     assert(!runningStride);
1538     return createIndexConstant(rewriter, loc, 1);
1539   }
1540 
1541   LogicalResult
1542   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1543                   ConversionPatternRewriter &rewriter) const override {
1544     auto loc = viewOp.getLoc();
1545 
1546     auto viewMemRefType = viewOp.getType();
1547     auto targetElementTy =
1548         typeConverter->convertType(viewMemRefType.getElementType());
1549     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1550     if (!targetDescTy || !targetElementTy ||
1551         !LLVM::isCompatibleType(targetElementTy) ||
1552         !LLVM::isCompatibleType(targetDescTy))
1553       return viewOp.emitWarning("target descriptor type not converted to LLVM"),
1554              failure();
1555 
1556     int64_t offset;
1557     SmallVector<int64_t, 4> strides;
1558     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1559     if (failed(successStrides))
1560       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1561     assert(offset == 0 && "expected offset to be 0");
1562 
1563     // Create the descriptor.
1564     MemRefDescriptor sourceMemRef(adaptor.source());
1565     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1566 
1567     // Field 1: Copy the allocated pointer, used for malloc/free.
1568     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1569     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1570     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1571         loc,
1572         LLVM::LLVMPointerType::get(targetElementTy,
1573                                    srcMemRefType.getMemorySpaceAsInt()),
1574         allocatedPtr);
1575     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1576 
1577     // Field 2: Copy the actual aligned pointer to payload.
1578     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1579     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1580                                               alignedPtr, adaptor.byte_shift());
1581     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1582         loc,
1583         LLVM::LLVMPointerType::get(targetElementTy,
1584                                    srcMemRefType.getMemorySpaceAsInt()),
1585         alignedPtr);
1586     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1587 
1588     // Field 3: The offset in the resulting type must be 0. This is because of
1589     // the type change: an offset on srcType* may not be expressible as an
1590     // offset on dstType*.
1591     targetMemRef.setOffset(rewriter, loc,
1592                            createIndexConstant(rewriter, loc, offset));
1593 
1594     // Early exit for 0-D corner case.
1595     if (viewMemRefType.getRank() == 0)
1596       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1597 
1598     // Fields 4 and 5: Update sizes and strides.
1599     if (strides.back() != 1)
1600       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1601              failure();
1602     Value stride = nullptr, nextSize = nullptr;
1603     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1604       // Update size.
1605       Value size =
1606           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1607       targetMemRef.setSize(rewriter, loc, i, size);
1608       // Update stride.
1609       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1610       targetMemRef.setStride(rewriter, loc, i, stride);
1611       nextSize = size;
1612     }
1613 
1614     rewriter.replaceOp(viewOp, {targetMemRef});
1615     return success();
1616   }
1617 };
1618 
1619 //===----------------------------------------------------------------------===//
1620 // AtomicRMWOpLowering
1621 //===----------------------------------------------------------------------===//
1622 
1623 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1624 /// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
1625 static Optional<LLVM::AtomicBinOp>
1626 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1627   switch (atomicOp.kind()) {
1628   case arith::AtomicRMWKind::addf:
1629     return LLVM::AtomicBinOp::fadd;
1630   case arith::AtomicRMWKind::addi:
1631     return LLVM::AtomicBinOp::add;
1632   case arith::AtomicRMWKind::assign:
1633     return LLVM::AtomicBinOp::xchg;
1634   case arith::AtomicRMWKind::maxs:
1635     return LLVM::AtomicBinOp::max;
1636   case arith::AtomicRMWKind::maxu:
1637     return LLVM::AtomicBinOp::umax;
1638   case arith::AtomicRMWKind::mins:
1639     return LLVM::AtomicBinOp::min;
1640   case arith::AtomicRMWKind::minu:
1641     return LLVM::AtomicBinOp::umin;
1642   case arith::AtomicRMWKind::ori:
1643     return LLVM::AtomicBinOp::_or;
1644   case arith::AtomicRMWKind::andi:
1645     return LLVM::AtomicBinOp::_and;
1646   default:
1647     return llvm::None;
1648   }
1649   llvm_unreachable("Invalid AtomicRMWKind");
1650 }
1651 
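/// Lowers memref.atomic_rmw to llvm.atomicrmw when the kind has a direct LLVM
/// counterpart (see matchSimpleAtomicOp above); otherwise the pattern fails to
/// match. Illustrative example (a sketch):
///   %x = memref.atomic_rmw addf %v, %m[%i] : (f32, memref<10xf32>) -> f32
/// becomes an llvm.atomicrmw fadd on the strided element pointer for %m[%i]
/// with acq_rel ordering.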
1652 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1653   using Base::Base;
1654 
1655   LogicalResult
1656   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1657                   ConversionPatternRewriter &rewriter) const override {
1658     if (failed(match(atomicOp)))
1659       return failure();
1660     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1661     if (!maybeKind)
1662       return failure();
1663     auto resultType = adaptor.value().getType();
1664     auto memRefType = atomicOp.getMemRefType();
1665     auto dataPtr =
1666         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1667                              adaptor.indices(), rewriter);
1668     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1669         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1670         LLVM::AtomicOrdering::acq_rel);
1671     return success();
1672   }
1673 };
1674 
1675 } // namespace
1676 
1677 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1678                                                   RewritePatternSet &patterns) {
1679   // clang-format off
1680   patterns.add<
1681       AllocaOpLowering,
1682       AllocaScopeOpLowering,
1683       AtomicRMWOpLowering,
1684       AssumeAlignmentOpLowering,
1685       DimOpLowering,
1686       GlobalMemrefOpLowering,
1687       GetGlobalMemrefOpLowering,
1688       LoadOpLowering,
1689       MemRefCastOpLowering,
1690       MemRefCopyOpLowering,
1691       MemRefReinterpretCastOpLowering,
1692       MemRefReshapeOpLowering,
1693       PrefetchOpLowering,
1694       RankOpLowering,
1695       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1696       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1697       StoreOpLowering,
1698       SubViewOpLowering,
1699       TransposeOpLowering,
1700       ViewOpLowering>(converter);
1701   // clang-format on
1702   auto allocLowering = converter.getOptions().allocLowering;
1703   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1704     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1705   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1706     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1707 }
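
// A minimal usage sketch (illustrative; it mirrors the pass defined below and
// assumes the caller has already set up `ctx`, `op`, and a conversion
// `target`):
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateMemRefToLLVMConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     ...; // handle the failure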
1708 
1709 namespace {
1710 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1711   MemRefToLLVMPass() = default;
1712 
1713   void runOnOperation() override {
1714     Operation *op = getOperation();
1715     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1716     LowerToLLVMOptions options(&getContext(),
1717                                dataLayoutAnalysis.getAtOrAbove(op));
1718     options.allocLowering =
1719         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1720                          : LowerToLLVMOptions::AllocLowering::Malloc);
1721     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1722       options.overrideIndexBitwidth(indexBitwidth);
1723 
1724     LLVMTypeConverter typeConverter(&getContext(), options,
1725                                     &dataLayoutAnalysis);
1726     RewritePatternSet patterns(&getContext());
1727     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1728     LLVMConversionTarget target(getContext());
1729     target.addLegalOp<FuncOp>();
1730     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1731       signalPassFailure();
1732   }
1733 };
1734 } // namespace
1735 
1736 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1737   return std::make_unique<MemRefToLLVMPass>();
1738 }
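
// Typical registration in a pass pipeline (a sketch; assumes an existing
// mlir::PassManager `pm`):
//   pm.addPass(createMemRefToLLVMPass());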
1739