1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 
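// Lowers memref.alloc to a malloc call plus manual alignment. As a rough
// sketch (assuming 64-bit indices and typed LLVM pointers; the exact size
// computation and descriptor construction are elided):
//
//   %m = memref.alloc(%d) : memref<?xf32>
//
// becomes approximately
//
//   %size = ... %d * sizeof(f32), plus alignment padding if requested ...
//   %raw  = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
//   %ptr  = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
//   ... %ptr and its aligned counterpart are inserted into the descriptor.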
26 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
27   AllocOpLowering(LLVMTypeConverter &converter)
28       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
29                                 converter) {}
30 
31   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
32                                           Location loc, Value sizeBytes,
33                                           Operation *op) const override {
34     // Heap allocations.
35     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
36     MemRefType memRefType = allocOp.getType();
37 
38     Value alignment;
39     if (auto alignmentAttr = allocOp.alignment()) {
40       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
41     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
42       // In the case where no alignment is specified, we may want to override
43       // `malloc's` behavior. `malloc` typically aligns at the size of the
44       // biggest scalar on a target HW. For non-scalars, use the natural
45       // alignment of the LLVM type given by the LLVM DataLayout.
46       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
47     }
48 
49     if (alignment) {
50       // Adjust the allocation size to consider alignment.
51       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
52     }
53 
54     // Allocate the underlying buffer and store a pointer to it in the MemRef
55     // descriptor.
56     Type elementPtrType = this->getElementPtrType(memRefType);
57     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
58         allocOp->getParentOfType<ModuleOp>(), getIndexType());
59     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
60                                   getVoidPtrType());
61     Value allocatedPtr =
62         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
63 
64     Value alignedPtr = allocatedPtr;
65     if (alignment) {
66       // Compute the aligned type pointer.
67       Value allocatedInt =
68           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
69       Value alignmentInt =
70           createAligned(rewriter, loc, allocatedInt, alignment);
71       alignedPtr =
72           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
73     }
74 
75     return std::make_tuple(allocatedPtr, alignedPtr);
76   }
77 };
78 
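// Lowers memref.alloc to a call to aligned_alloc when the pass is configured
// to use it. A rough sketch (padding and descriptor construction elided):
//
//   %m = memref.alloc() {alignment = 32} : memref<16xf32>
//
// becomes approximately
//
//   %align = llvm.mlir.constant(32 : index) : i64
//   %size  = ... allocation size, padded to a multiple of %align ...
//   %raw   = llvm.call @aligned_alloc(%align, %size)
//              : (i64, i64) -> !llvm.ptr<i8>
//   %ptr   = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>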
79 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
80   AlignedAllocOpLowering(LLVMTypeConverter &converter)
81       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
82                                 converter) {}
83 
84   /// Returns the memref's element size in bytes using the data layout active at
85   /// `op`.
86   // TODO: there are other places where this is used. Expose publicly?
87   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
88     const DataLayout *layout = &defaultLayout;
89     if (const DataLayoutAnalysis *analysis =
90             getTypeConverter()->getDataLayoutAnalysis()) {
91       layout = &analysis->getAbove(op);
92     }
93     Type elementType = memRefType.getElementType();
94     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
95       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
96                                                          *layout);
97     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
98       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
99           memRefElementType, *layout);
100     return layout->getTypeSize(elementType);
101   }
102 
  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
105   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
106                               Operation *op) const {
107     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
108     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
109       if (ShapedType::isDynamic(type.getDimSize(i)))
110         continue;
111       sizeDivisor = sizeDivisor * type.getDimSize(i);
112     }
113     return sizeDivisor % factor == 0;
114   }
115 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
119   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
120     if (Optional<uint64_t> alignment = allocOp.alignment())
121       return *alignment;
122 
    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it isn't
    // one already.
126     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
127     return std::max(kMinAlignedAllocAlignment,
128                     llvm::PowerOf2Ceil(eltSizeBytes));
129   }
130 
131   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
132                                           Location loc, Value sizeBytes,
133                                           Operation *op) const override {
134     // Heap allocations.
135     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
136     MemRefType memRefType = allocOp.getType();
137     int64_t alignment = getAllocationAlignment(allocOp);
138     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
139 
140     // aligned_alloc requires size to be a multiple of alignment; we will pad
141     // the size to the next multiple if necessary.
142     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
143       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
144 
145     Type elementPtrType = this->getElementPtrType(memRefType);
146     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
147         allocOp->getParentOfType<ModuleOp>(), getIndexType());
148     auto results =
149         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
150                        getVoidPtrType());
151     Value allocatedPtr =
152         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
153 
154     return std::make_tuple(allocatedPtr, allocatedPtr);
155   }
156 
157   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
158   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
159 
160   /// Default layout to use in absence of the corresponding analysis.
161   DataLayout defaultLayout;
162 };
163 
// Out-of-line definition, required until C++17.
165 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
166 
167 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
168   AllocaOpLowering(LLVMTypeConverter &converter)
169       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
170                                 converter) {}
171 
  /// Allocates the underlying buffer with an LLVM `alloca` of the element
  /// type. For stack allocations the allocated and aligned pointers coincide,
  /// so the same value is returned for both.
175   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
176                                           Location loc, Value sizeBytes,
177                                           Operation *op) const override {
178 
    // With alloca, one gets a pointer to the element type right away; the
    // buffer is allocated directly on the stack.
181     auto allocaOp = cast<memref::AllocaOp>(op);
182     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
183 
184     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
185         loc, elementPtrType, sizeBytes,
186         allocaOp.alignment() ? *allocaOp.alignment() : 0);
187 
188     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
189   }
190 };
191 
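// Lowers memref.alloca_scope by inlining its body between an
// llvm.intr.stacksave / llvm.intr.stackrestore pair. A sketch of the
// resulting CFG (block names are illustrative):
//
//   ^current:
//     %sp = llvm.intr.stacksave : !llvm.ptr<i8>
//     llvm.br ^body
//   ^body:                       // inlined alloca_scope region
//     ...
//     llvm.intr.stackrestore %sp
//     llvm.br ^continue(...)
//   ^continue(...):              // receives the alloca_scope results, if any
//     ...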
192 struct AllocaScopeOpLowering
193     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
194   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
195 
196   LogicalResult
197   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override {
199     OpBuilder::InsertionGuard guard(rewriter);
200     Location loc = allocaScopeOp.getLoc();
201 
202     // Split the current block before the AllocaScopeOp to create the inlining
203     // point.
204     auto *currentBlock = rewriter.getInsertionBlock();
205     auto *remainingOpsBlock =
206         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
207     Block *continueBlock;
208     if (allocaScopeOp.getNumResults() == 0) {
209       continueBlock = remainingOpsBlock;
210     } else {
211       continueBlock = rewriter.createBlock(remainingOpsBlock,
212                                            allocaScopeOp.getResultTypes());
213       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
214     }
215 
216     // Inline body region.
217     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
218     Block *afterBody = &allocaScopeOp.bodyRegion().back();
219     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
220 
221     // Save stack and then branch into the body of the region.
222     rewriter.setInsertionPointToEnd(currentBlock);
223     auto stackSaveOp =
224         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
225     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
226 
    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
229     rewriter.setInsertionPointToEnd(afterBody);
230     auto returnOp =
231         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
232     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
233         returnOp, returnOp.results(), continueBlock);
234 
    // Insert the stack restore before jumping out of the body of the region.
236     rewriter.setInsertionPoint(branchOp);
237     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
238 
    // Replace the op with the values returned from the body region.
240     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
241 
242     return success();
243   }
244 };
245 
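// Lowers memref.assume_alignment to an llvm.intr.assume on the aligned
// pointer. Roughly (a sketch, for `alignment = 16`):
//
//   %p    = ... aligned pointer extracted from the descriptor ...
//   %pi   = llvm.ptrtoint %p : !llvm.ptr<f32> to i64
//   %low  = llvm.and %pi, %c15 : i64
//   %cond = llvm.icmp "eq" %low, %c0 : i64
//   "llvm.intr.assume"(%cond) : (i1) -> ()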
246 struct AssumeAlignmentOpLowering
247     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
248   using ConvertOpToLLVMPattern<
249       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
250 
251   LogicalResult
252   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
253                   ConversionPatternRewriter &rewriter) const override {
254     Value memref = adaptor.memref();
255     unsigned alignment = op.alignment();
256     auto loc = op.getLoc();
257 
258     MemRefDescriptor memRefDescriptor(memref);
259     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
260 
261     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
262     // the asserted memref.alignedPtr isn't used anywhere else, as the real
263     // users like load/store/views always re-extract memref.alignedPtr as they
264     // get lowered.
265     //
266     // This relies on LLVM's CSE optimization (potentially after SROA), since
267     // after CSE all memref.alignedPtr instances get de-duplicated into the same
268     // pointer SSA value.
269     auto intPtrType =
270         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
271     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
272     Value mask =
273         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
274     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
275     rewriter.create<LLVM::AssumeOp>(
276         loc, rewriter.create<LLVM::ICmpOp>(
277                  loc, LLVM::ICmpPredicate::eq,
278                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
279 
280     rewriter.eraseOp(op);
281     return success();
282   }
283 };
284 
285 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
286 // The memref descriptor being an SSA value, there is no need to clean it up
287 // in any way.
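// As a rough sketch:
//
//   memref.dealloc %m : memref<?xf32>
//
// becomes approximately
//
//   %alloc = ... allocated pointer extracted from the descriptor ...
//   %void  = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()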
288 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
289   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
290 
291   explicit DeallocOpLowering(LLVMTypeConverter &converter)
292       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
293 
294   LogicalResult
295   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
296                   ConversionPatternRewriter &rewriter) const override {
297     // Insert the `free` declaration if it is not already present.
298     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
299     MemRefDescriptor memref(adaptor.memref());
300     Value casted = rewriter.create<LLVM::BitcastOp>(
301         op.getLoc(), getVoidPtrType(),
302         memref.allocatedPtr(rewriter, op.getLoc()));
303     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
304         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
305     return success();
306   }
307 };
308 
309 // A `dim` is converted to a constant for static sizes and to an access to the
310 // size stored in the memref descriptor for dynamic sizes.
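// For example (a sketch; the converted descriptor is a struct of the form
// {allocated ptr, aligned ptr, offset, sizes array, strides array}):
//
//   %s = memref.dim %m, %c1 : memref<4x?xf32>
//
// becomes, roughly, an extraction of the dynamic size from the descriptor:
//
//   %s = llvm.extractvalue %desc[3, 1]
//          : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
//                          array<2 x i64>, array<2 x i64>)>
//
// while `memref.dim %m, %c0` simply becomes the constant 4.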
311 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
312   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
313 
314   LogicalResult
315   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
316                   ConversionPatternRewriter &rewriter) const override {
317     Type operandType = dimOp.source().getType();
318     if (operandType.isa<UnrankedMemRefType>()) {
319       rewriter.replaceOp(
320           dimOp, {extractSizeOfUnrankedMemRef(
321                      operandType, dimOp, adaptor.getOperands(), rewriter)});
322 
323       return success();
324     }
325     if (operandType.isa<MemRefType>()) {
326       rewriter.replaceOp(
327           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
328                                             adaptor.getOperands(), rewriter)});
329       return success();
330     }
331     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
332   }
333 
334 private:
335   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
336                                     OpAdaptor adaptor,
337                                     ConversionPatternRewriter &rewriter) const {
338     Location loc = dimOp.getLoc();
339 
340     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
341     auto scalarMemRefType =
342         MemRefType::get({}, unrankedMemRefType.getElementType());
343     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
344 
345     // Extract pointer to the underlying ranked descriptor and bitcast it to a
346     // memref<element_type> descriptor pointer to minimize the number of GEP
347     // operations.
348     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
349     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
350     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
351         loc,
352         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
353                                    addressSpace),
354         underlyingRankedDesc);
355 
356     // Get pointer to offset field of memref<element_type> descriptor.
357     Type indexPtrTy = LLVM::LLVMPointerType::get(
358         getTypeConverter()->getIndexType(), addressSpace);
359     Value two = rewriter.create<LLVM::ConstantOp>(
360         loc, typeConverter->convertType(rewriter.getI32Type()),
361         rewriter.getI32IntegerAttr(2));
362     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
363         loc, indexPtrTy, scalarMemRefDescPtr,
364         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
365 
    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
368     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
369         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
370     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
371                                                  ValueRange({idxPlusOne}));
372     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
373   }
374 
375   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
376     if (Optional<int64_t> idx = dimOp.getConstantIndex())
377       return idx;
378 
379     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
380       return constantOp.getValue()
381           .cast<IntegerAttr>()
382           .getValue()
383           .getSExtValue();
384 
385     return llvm::None;
386   }
387 
388   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
389                                   OpAdaptor adaptor,
390                                   ConversionPatternRewriter &rewriter) const {
391     Location loc = dimOp.getLoc();
392 
    // Take advantage of a statically known index, if any.
394     MemRefType memRefType = operandType.cast<MemRefType>();
395     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
396       int64_t i = index.getValue();
397       if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
399         MemRefDescriptor descriptor(adaptor.source());
400         return descriptor.size(rewriter, loc, i);
401       }
402       // Use constant for static size.
403       int64_t dimSize = memRefType.getDimSize(i);
404       return createIndexConstant(rewriter, loc, dimSize);
405     }
406     Value index = adaptor.index();
407     int64_t rank = memRefType.getRank();
408     MemRefDescriptor memrefDescriptor(adaptor.source());
409     return memrefDescriptor.size(rewriter, loc, index, rank);
410   }
411 };
412 
413 /// Returns the LLVM type of the global variable given the memref type `type`.
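/// For example, memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>>, and a
/// rank-0 memref<f32> maps to plain f32.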
414 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
415                                           LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from memref
  // to the LLVM dialect, so the LLVM type needs to be a multi-dimensional
  // array.
421   Type elementType = typeConverter.convertType(type.getElementType());
422   Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so we walk it backwards.
424   for (int64_t dim : llvm::reverse(type.getShape()))
425     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
426   return arrayTy;
427 }
428 
/// GlobalMemrefOp is lowered to an LLVM global variable.
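/// As a rough sketch (linkage and attributes depend on the memref.global):
///
///   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
///
/// becomes approximately
///
///   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>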
430 struct GlobalMemrefOpLowering
431     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
432   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
433 
434   LogicalResult
435   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
436                   ConversionPatternRewriter &rewriter) const override {
437     MemRefType type = global.type();
438     if (!isConvertibleAndHasIdentityMaps(type))
439       return failure();
440 
441     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
442 
443     LLVM::Linkage linkage =
444         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
445 
446     Attribute initialValue = nullptr;
447     if (!global.isExternal() && !global.isUninitialized()) {
448       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
449       initialValue = elementsAttr;
450 
451       // For scalar memrefs, the global variable created is of the element type,
452       // so unpack the elements attribute to extract the value.
453       if (type.getRank() == 0)
454         initialValue = elementsAttr.getSplatValue<Attribute>();
455     }
456 
457     uint64_t alignment = global.alignment().getValueOr(0);
458 
459     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
460         global, arrayTy, global.constant(), linkage, global.sym_name(),
461         initialValue, alignment, type.getMemorySpaceAsInt());
462     if (!global.isExternal() && global.isUninitialized()) {
463       Block *blk = new Block();
464       newGlobal.getInitializerRegion().push_back(blk);
465       rewriter.setInsertionPointToStart(blk);
466       Value undef[] = {
467           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
468       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
469     }
470     return success();
471   }
472 };
473 
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. It derives from
/// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
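/// Roughly (a sketch, for `memref.get_global @gv : memref<2x3xf32>`), the
/// lowering emits
///
///   %addr = llvm.mlir.addressof @gv : !llvm.ptr<array<2 x array<3 x f32>>>
///   %elem = llvm.getelementptr %addr[%c0, %c0, %c0]
///             : (...) -> !llvm.ptr<f32>
///
/// and stashes %elem as the aligned pointer of the resulting descriptor, with
/// the allocated pointer set to the 0xdeadbeef marker described in
/// allocateBuffer below.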
477 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
478   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
479       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
480                                 converter) {}
481 
  /// Buffer "allocation" for a memref.get_global op amounts to taking the
  /// address of the referenced global variable.
484   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
485                                           Location loc, Value sizeBytes,
486                                           Operation *op) const override {
487     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
488     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
489     unsigned memSpace = type.getMemorySpaceAsInt();
490 
491     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
492     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
493         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
494 
495     // Get the address of the first element in the array by creating a GEP with
496     // the address of the GV as the base, and (rank + 1) number of 0 indices.
497     Type elementType = typeConverter->convertType(type.getElementType());
498     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
499 
500     SmallVector<Value> operands;
501     operands.insert(operands.end(), type.getRank() + 1,
502                     createIndexConstant(rewriter, loc, 0));
503     auto gep =
504         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
505 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debugging if that ever happens.
509     auto intPtrType = getIntPtrType(memSpace);
510     Value deadBeefConst =
511         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
512     auto deadBeefPtr =
513         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
514 
    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
517     return std::make_tuple(deadBeefPtr, gep);
518   }
519 };
520 
521 // Common base for load and store operations on MemRefs. Restricts the match
522 // to supported MemRef types. Provides functionality to emit code accessing a
523 // specific element of the underlying data buffer.
524 template <typename Derived>
525 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
526   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
527   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
528   using Base = LoadStoreOpLowering<Derived>;
529 
530   LogicalResult match(Derived op) const override {
531     MemRefType type = op.getMemRefType();
532     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
533   }
534 };
535 
536 // Load operation is lowered to obtaining a pointer to the indexed element
537 // and loading it.
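// For instance (a sketch; the address arithmetic uses the descriptor's offset
// and strides):
//
//   %v = memref.load %m[%i, %j] : memref<4x?xf32>
//
// becomes approximately
//
//   %base = ... aligned pointer from the descriptor ...
//   %idx  = ... offset + %i * stride0 + %j * stride1 ...
//   %addr = llvm.getelementptr %base[%idx]
//             : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
//   %v    = llvm.load %addr : !llvm.ptr<f32>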
538 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
539   using Base::Base;
540 
541   LogicalResult
542   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
543                   ConversionPatternRewriter &rewriter) const override {
544     auto type = loadOp.getMemRefType();
545 
546     Value dataPtr = getStridedElementPtr(
547         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
548     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
549     return success();
550   }
551 };
552 
553 // Store operation is lowered to obtaining a pointer to the indexed element,
554 // and storing the given value to it.
555 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
556   using Base::Base;
557 
558   LogicalResult
559   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
560                   ConversionPatternRewriter &rewriter) const override {
561     auto type = op.getMemRefType();
562 
563     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
564                                          adaptor.indices(), rewriter);
565     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
566     return success();
567   }
568 };
569 
570 // The prefetch operation is lowered in a way similar to the load operation
571 // except that the llvm.prefetch operation is used for replacement.
572 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
573   using Base::Base;
574 
575   LogicalResult
576   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
577                   ConversionPatternRewriter &rewriter) const override {
578     auto type = prefetchOp.getMemRefType();
579     auto loc = prefetchOp.getLoc();
580 
581     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
582                                          adaptor.indices(), rewriter);
583 
584     // Replace with llvm.prefetch.
585     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
586     auto isWrite = rewriter.create<LLVM::ConstantOp>(
587         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
588     auto localityHint = rewriter.create<LLVM::ConstantOp>(
589         loc, llvmI32Type,
590         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
591     auto isData = rewriter.create<LLVM::ConstantOp>(
592         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
593 
594     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
595                                                 localityHint, isData);
596     return success();
597   }
598 };
599 
600 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
601   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
602 
603   LogicalResult
604   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
605                   ConversionPatternRewriter &rewriter) const override {
606     Location loc = op.getLoc();
607     Type operandType = op.memref().getType();
608     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
609       UnrankedMemRefDescriptor desc(adaptor.memref());
610       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
611       return success();
612     }
613     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
614       rewriter.replaceOp(
615           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
616       return success();
617     }
618     return failure();
619   }
620 };
621 
622 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
623   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
624 
625   LogicalResult match(memref::CastOp memRefCastOp) const override {
626     Type srcType = memRefCastOp.getOperand().getType();
627     Type dstType = memRefCastOp.getType();
628 
    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed we can revisit.
634     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
635       return success(typeConverter->convertType(srcType) ==
636                      typeConverter->convertType(dstType));
637 
    // At least one of the operands has an unranked memref type.
639     assert(srcType.isa<UnrankedMemRefType>() ||
640            dstType.isa<UnrankedMemRefType>());
641 
    // Unranked-to-unranked casts are disallowed.
643     return !(srcType.isa<UnrankedMemRefType>() &&
644              dstType.isa<UnrankedMemRefType>())
645                ? success()
646                : failure();
647   }
648 
649   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
650                ConversionPatternRewriter &rewriter) const override {
651     auto srcType = memRefCastOp.getOperand().getType();
652     auto dstType = memRefCastOp.getType();
653     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
654     auto loc = memRefCastOp.getLoc();
655 
656     // For ranked/ranked case, just keep the original descriptor.
657     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
658       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
659 
660     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting a ranked memref type to an unranked one:
      //   - set the rank in the destination from the memref type;
      //   - allocate space on the stack and copy the src memref descriptor;
      //   - set the ptr in the destination to the stack space.
665       auto srcMemRefType = srcType.cast<MemRefType>();
666       int64_t rank = srcMemRefType.getRank();
667       // ptr = AllocaOp sizeof(MemRefDescriptor)
668       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
669           loc, adaptor.source(), rewriter);
670       // voidptr = BitCastOp srcType* to void*
671       auto voidPtr =
672           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
673               .getResult();
674       // rank = ConstantOp srcRank
675       auto rankVal = rewriter.create<LLVM::ConstantOp>(
676           loc, typeConverter->convertType(rewriter.getIntegerType(64)),
677           rewriter.getI64IntegerAttr(rank));
678       // undef = UndefOp
679       UnrankedMemRefDescriptor memRefDesc =
680           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
681       // d1 = InsertValueOp undef, rank, 0
682       memRefDesc.setRank(rewriter, loc, rankVal);
683       // d2 = InsertValueOp d1, voidptr, 1
684       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
685       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
686 
687     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
688       // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, the behavior is
      // undefined.
691       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
692       // ptr = ExtractValueOp src, 1
693       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
694       // castPtr = BitCastOp i8* to structTy*
695       auto castPtr =
696           rewriter
697               .create<LLVM::BitcastOp>(
698                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
699               .getResult();
700       // struct = LoadOp castPtr
701       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
702       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
703     } else {
704       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
705     }
706   }
707 };
708 
709 /// Pattern to lower a `memref.copy` to llvm.
710 ///
711 /// For memrefs with identity layouts, the copy is lowered to the llvm
712 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
713 /// to the generic `MemrefCopyFn`.
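/// For the intrinsic path, the generated IR is roughly (a sketch):
///
///   %bytes = ... number of elements * element size in bytes ...
///   "llvm.intr.memcpy"(%dst, %src, %bytes, %false)
///       : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
///
/// For the library path, both operands are packed into unranked descriptors,
/// promoted to the stack, and passed to the generic copy function together
/// with the element size.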
714 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
715   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
716 
717   LogicalResult
718   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
719                           ConversionPatternRewriter &rewriter) const {
720     auto loc = op.getLoc();
721     auto srcType = op.source().getType().dyn_cast<MemRefType>();
722 
723     MemRefDescriptor srcDesc(adaptor.source());
724 
725     // Compute number of elements.
726     Value numElements = rewriter.create<LLVM::ConstantOp>(
727         loc, getIndexType(), rewriter.getIndexAttr(1));
728     for (int pos = 0; pos < srcType.getRank(); ++pos) {
729       auto size = srcDesc.size(rewriter, loc, pos);
730       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
731     }
732 
733     // Get element size.
734     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
735     // Compute total.
736     Value totalSize =
737         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
738 
739     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
740     MemRefDescriptor targetDesc(adaptor.target());
741     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
742     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
743         loc, typeConverter->convertType(rewriter.getI1Type()),
744         rewriter.getBoolAttr(false));
745     rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
746                                     isVolatile);
747     rewriter.eraseOp(op);
748 
749     return success();
750   }
751 
752   LogicalResult
753   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
754                              ConversionPatternRewriter &rewriter) const {
755     auto loc = op.getLoc();
756     auto srcType = op.source().getType().cast<BaseMemRefType>();
757     auto targetType = op.target().getType().cast<BaseMemRefType>();
758 
759     // First make sure we have an unranked memref descriptor representation.
760     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
761       auto rank = rewriter.create<LLVM::ConstantOp>(
762           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
763       auto *typeConverter = getTypeConverter();
764       auto ptr =
765           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
766       auto voidPtr =
767           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
768               .getResult();
769       auto unrankedType =
770           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
771       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
772                                             unrankedType,
773                                             ValueRange{rank, voidPtr});
774     };
775 
776     Value unrankedSource = srcType.hasRank()
777                                ? makeUnranked(adaptor.source(), srcType)
778                                : adaptor.source();
779     Value unrankedTarget = targetType.hasRank()
780                                ? makeUnranked(adaptor.target(), targetType)
781                                : adaptor.target();
782 
783     // Now promote the unranked descriptors to the stack.
784     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
785                                                  rewriter.getIndexAttr(1));
786     auto promote = [&](Value desc) {
787       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
788       auto allocated =
789           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
790       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
791       return allocated;
792     };
793 
794     auto sourcePtr = promote(unrankedSource);
795     auto targetPtr = promote(unrankedTarget);
796 
797     auto elemSize = rewriter.create<LLVM::ConstantOp>(
798         loc, getIndexType(),
799         rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
800     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
801         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
802     rewriter.create<LLVM::CallOp>(loc, copyFn,
803                                   ValueRange{elemSize, sourcePtr, targetPtr});
804     rewriter.eraseOp(op);
805 
806     return success();
807   }
808 
809   LogicalResult
810   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
811                   ConversionPatternRewriter &rewriter) const override {
812     auto srcType = op.source().getType().cast<BaseMemRefType>();
813     auto targetType = op.target().getType().cast<BaseMemRefType>();
814 
815     if (srcType.hasRank() &&
816         srcType.cast<MemRefType>().getLayout().isIdentity() &&
817         targetType.hasRank() &&
818         targetType.cast<MemRefType>().getLayout().isIdentity())
819       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
820 
821     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
822   }
823 };
824 
/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
828 static void extractPointersAndOffset(Location loc,
829                                      ConversionPatternRewriter &rewriter,
830                                      LLVMTypeConverter &typeConverter,
831                                      Value originalOperand,
832                                      Value convertedOperand,
833                                      Value *allocatedPtr, Value *alignedPtr,
834                                      Value *offset = nullptr) {
835   Type operandType = originalOperand.getType();
836   if (operandType.isa<MemRefType>()) {
837     MemRefDescriptor desc(convertedOperand);
838     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
839     *alignedPtr = desc.alignedPtr(rewriter, loc);
840     if (offset != nullptr)
841       *offset = desc.offset(rewriter, loc);
842     return;
843   }
844 
845   unsigned memorySpace =
846       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
847   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
848   Type llvmElementType = typeConverter.convertType(elementType);
849   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
850       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
851 
852   // Extract pointer to the underlying ranked memref descriptor and cast it to
853   // ElemType**.
854   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
855   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
856 
857   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
858       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
859   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
860       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
861   if (offset != nullptr) {
862     *offset = UnrankedMemRefDescriptor::offset(
863         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
864   }
865 }
866 
867 struct MemRefReinterpretCastOpLowering
868     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
869   using ConvertOpToLLVMPattern<
870       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
871 
872   LogicalResult
873   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
874                   ConversionPatternRewriter &rewriter) const override {
875     Type srcType = castOp.source().getType();
876 
877     Value descriptor;
878     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
879                                                adaptor, &descriptor)))
880       return failure();
881     rewriter.replaceOp(castOp, {descriptor});
882     return success();
883   }
884 
885 private:
886   LogicalResult convertSourceMemRefToDescriptor(
887       ConversionPatternRewriter &rewriter, Type srcType,
888       memref::ReinterpretCastOp castOp,
889       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
890     MemRefType targetMemRefType =
891         castOp.getResult().getType().cast<MemRefType>();
892     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
893                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
894     if (!llvmTargetDescriptorTy)
895       return failure();
896 
897     // Create descriptor.
898     Location loc = castOp.getLoc();
899     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
900 
901     // Set allocated and aligned pointers.
902     Value allocatedPtr, alignedPtr;
903     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
904                              castOp.source(), adaptor.source(), &allocatedPtr,
905                              &alignedPtr);
906     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
907     desc.setAlignedPtr(rewriter, loc, alignedPtr);
908 
909     // Set offset.
910     if (castOp.isDynamicOffset(0))
911       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
912     else
913       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
914 
915     // Set sizes and strides.
916     unsigned dynSizeId = 0;
917     unsigned dynStrideId = 0;
918     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
919       if (castOp.isDynamicSize(i))
920         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
921       else
922         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
923 
924       if (castOp.isDynamicStride(i))
925         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
926       else
927         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
928     }
929     *descriptor = desc;
930     return success();
931   }
932 };
933 
934 struct MemRefReshapeOpLowering
935     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
936   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
937 
938   LogicalResult
939   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
940                   ConversionPatternRewriter &rewriter) const override {
941     Type srcType = reshapeOp.source().getType();
942 
943     Value descriptor;
944     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
945                                                adaptor, &descriptor)))
946       return failure();
947     rewriter.replaceOp(reshapeOp, {descriptor});
948     return success();
949   }
950 
951 private:
952   LogicalResult
953   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
954                                   Type srcType, memref::ReshapeOp reshapeOp,
955                                   memref::ReshapeOp::Adaptor adaptor,
956                                   Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref.reinterpret_cast`.
959     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
960     if (shapeMemRefType.hasStaticShape())
961       return failure();
962 
    // The shape is a rank-1 memref with unknown length.
964     Location loc = reshapeOp.getLoc();
965     MemRefDescriptor shapeDesc(adaptor.shape());
966     Value resultRank = shapeDesc.size(rewriter, loc, 0);
967 
968     // Extract address space and element type.
969     auto targetType =
970         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
971     unsigned addressSpace = targetType.getMemorySpaceAsInt();
972     Type elementType = targetType.getElementType();
973 
974     // Create the unranked memref descriptor that holds the ranked one. The
975     // inner descriptor is allocated on stack.
976     auto targetDesc = UnrankedMemRefDescriptor::undef(
977         rewriter, loc, typeConverter->convertType(targetType));
978     targetDesc.setRank(rewriter, loc, resultRank);
979     SmallVector<Value, 4> sizes;
980     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
981                                            targetDesc, sizes);
982     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
983         loc, getVoidPtrType(), sizes.front(), llvm::None);
984     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
985 
986     // Extract pointers and offset from the source memref.
987     Value allocatedPtr, alignedPtr, offset;
988     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
989                              reshapeOp.source(), adaptor.source(),
990                              &allocatedPtr, &alignedPtr, &offset);
991 
992     // Set pointers and offset.
993     Type llvmElementType = typeConverter->convertType(elementType);
994     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
995         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
996     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
997                                               elementPtrPtrType, allocatedPtr);
998     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
999                                             underlyingDescPtr,
1000                                             elementPtrPtrType, alignedPtr);
1001     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1002                                         underlyingDescPtr, elementPtrPtrType,
1003                                         offset);
1004 
    // Use the offset pointer as the base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank - 1
    // down to 0.
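    // A sketch of the loop constructed below (block names are illustrative):
    //
    //   ^init:
    //     llvm.br ^cond(%rank - 1, %c1 : i64, i64)
    //   ^cond(%i, %stride):
    //     %pred = llvm.icmp "sge" %i, %c0
    //     llvm.cond_br %pred, ^body, ^remainder
    //   ^body:
    //     ... sizes[%i] = shape[%i]; strides[%i] = %stride ...
    //     llvm.br ^cond(%i - 1, %stride * shape[%i])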
1007     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1008         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1009         elementPtrPtrType);
1010     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1011         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1012     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1013     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1014     Value resultRankMinusOne =
1015         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1016 
1017     Block *initBlock = rewriter.getInsertionBlock();
1018     Type indexType = getTypeConverter()->getIndexType();
1019     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1020 
1021     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1022                                             {indexType, indexType});
1023 
1024     // Move the remaining initBlock ops to condBlock.
1025     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1026     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1027 
1028     rewriter.setInsertionPointToEnd(initBlock);
1029     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1030                                 condBlock);
1031     rewriter.setInsertionPointToStart(condBlock);
1032     Value indexArg = condBlock->getArgument(0);
1033     Value strideArg = condBlock->getArgument(1);
1034 
1035     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1036     Value pred = rewriter.create<LLVM::ICmpOp>(
1037         loc, IntegerType::get(rewriter.getContext(), 1),
1038         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1039 
1040     Block *bodyBlock =
1041         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1042     rewriter.setInsertionPointToStart(bodyBlock);
1043 
1044     // Copy size from shape to descriptor.
1045     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1046     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1047         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1048     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1049     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1050                                       targetSizesBase, indexArg, size);
1051 
1052     // Write stride value and compute next one.
1053     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1054                                         targetStridesBase, indexArg, strideArg);
1055     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1056 
1057     // Decrement loop counter and branch back.
1058     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1059     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1060                                 condBlock);
1061 
1062     Block *remainder =
1063         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1064 
1065     // Hook up the cond exit to the remainder.
1066     rewriter.setInsertionPointToEnd(condBlock);
1067     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1068                                     llvm::None);
1069 
1070     // Reset position to beginning of new remainder block.
1071     rewriter.setInsertionPointToStart(remainder);
1072 
1073     *descriptor = targetDesc;
1074     return success();
1075   }
1076 };
1077 
1078 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1079 /// `Value`s.
1080 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1081                                       Type &llvmIndexType,
1082                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1083   return llvm::to_vector<4>(
1084       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1085         if (auto attr = value.dyn_cast<Attribute>())
1086           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1087         return value.get<Value>();
1088       }));
1089 }
1090 
/// Computes a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
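/// For example, the reassociation [[0, 1], [2]] (collapsing dims 0 and 1 into
/// dim 0) yields the map {0 -> 0, 1 -> 0, 2 -> 1}.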
1094 static DenseMap<int64_t, int64_t>
1095 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1096   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1097   for (auto &en : enumerate(reassociation)) {
1098     for (auto dim : en.value())
1099       expandedDimToCollapsedDim[dim] = en.index();
1100   }
1101   return expandedDimToCollapsedDim;
1102 }
1103 
1104 static OpFoldResult
1105 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1106                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1107                          MemRefDescriptor &inDesc,
1108                          ArrayRef<int64_t> inStaticShape,
1109                          ArrayRef<ReassociationIndices> reassocation,
1110                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1111   int64_t outDimSize = outStaticShape[outDimIndex];
1112   if (!ShapedType::isDynamic(outDimSize))
1113     return b.getIndexAttr(outDimSize);
1114 
1115   // Calculate the multiplication of all the out dim sizes except the
1116   // current dim.
1117   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1118   int64_t otherDimSizesMul = 1;
1119   for (auto otherDimIndex : reassocation[inDimIndex]) {
1120     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1121       continue;
1122     int64_t otherDimSize = outStaticShape[otherDimIndex];
1123     assert(!ShapedType::isDynamic(otherDimSize) &&
1124            "single dimension cannot be expanded into multiple dynamic "
1125            "dimensions");
1126     otherDimSizesMul *= otherDimSize;
1127   }
1128 
1129   // outDimSize = inDimSize / otherOutDimSizesMul
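  // For example (a sketch): expanding a dynamic input dim of size %n into
  // output dims (2, ?, 3) gives the dynamic output dim as %n / 6.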
1130   int64_t inDimSize = inStaticShape[inDimIndex];
1131   Value inDimSizeDynamic =
1132       ShapedType::isDynamic(inDimSize)
1133           ? inDesc.size(b, loc, inDimIndex)
1134           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1135                                        b.getIndexAttr(inDimSize));
1136   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1137       loc, inDimSizeDynamic,
1138       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1139                                  b.getIndexAttr(otherDimSizesMul)));
1140   return outDimSizeDynamic;
1141 }
1142 
1143 static OpFoldResult getCollapsedOutputDimSize(
1144     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1145     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1146     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1147   if (!ShapedType::isDynamic(outDimSize))
1148     return b.getIndexAttr(outDimSize);
1149 
1150   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1151   Value outDimSizeDynamic = c1;
1152   for (auto inDimIndex : reassocation[outDimIndex]) {
1153     int64_t inDimSize = inStaticShape[inDimIndex];
1154     Value inDimSizeDynamic =
1155         ShapedType::isDynamic(inDimSize)
1156             ? inDesc.size(b, loc, inDimIndex)
1157             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1158                                          b.getIndexAttr(inDimSize));
1159     outDimSizeDynamic =
1160         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1161   }
1162   return outDimSizeDynamic;
1163 }
1164 
1165 static SmallVector<OpFoldResult, 4>
1166 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1167                         ArrayRef<ReassociationIndices> reassocation,
1168                         ArrayRef<int64_t> inStaticShape,
1169                         MemRefDescriptor &inDesc,
1170                         ArrayRef<int64_t> outStaticShape) {
1171   return llvm::to_vector<4>(llvm::map_range(
1172       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1173         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1174                                          outStaticShape[outDimIndex],
1175                                          inStaticShape, inDesc, reassocation);
1176       }));
1177 }
1178 
1179 static SmallVector<OpFoldResult, 4>
1180 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1181                        ArrayRef<ReassociationIndices> reassocation,
1182                        ArrayRef<int64_t> inStaticShape,
1183                        MemRefDescriptor &inDesc,
1184                        ArrayRef<int64_t> outStaticShape) {
1185   DenseMap<int64_t, int64_t> outDimToInDimMap =
1186       getExpandedDimToCollapsedDimMap(reassocation);
1187   return llvm::to_vector<4>(llvm::map_range(
1188       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1189         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1190                                         outStaticShape, inDesc, inStaticShape,
1191                                         reassocation, outDimToInDimMap);
1192       }));
1193 }
1194 
1195 static SmallVector<Value>
1196 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1197                       ArrayRef<ReassociationIndices> reassocation,
1198                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1199                       ArrayRef<int64_t> outStaticShape) {
1200   return outStaticShape.size() < inStaticShape.size()
1201              ? getAsValues(b, loc, llvmIndexType,
1202                            getCollapsedOutputShape(b, loc, llvmIndexType,
1203                                                    reassocation, inStaticShape,
1204                                                    inDesc, outStaticShape))
1205              : getAsValues(b, loc, llvmIndexType,
1206                            getExpandedOutputShape(b, loc, llvmIndexType,
1207                                                   reassocation, inStaticShape,
1208                                                   inDesc, outStaticShape));
1209 }
1210 
// Conversion for memref.expand_shape and memref.collapse_shape ops. A new
// MemRef descriptor of the result rank is created; the allocated and aligned
// pointers and the offset are copied from the source descriptor, while sizes
// and strides are recomputed from the reassociation. When either shape is
// dynamic, both the source and the result must have an identity layout.
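//
// For illustration (shapes and IR are schematic, not copied from a test):
//   %1 = memref.collapse_shape %0 [[0, 1]] : memref<4x?xf32> into memref<?xf32>
// reuses %0's pointers and offset and sets size(0) = 4 * size(%0, 1), while
//   %1 = memref.expand_shape %0 [[0, 1]] : memref<?xf32> into memref<4x?xf32>
// recovers the dynamic size(1) as size(%0, 0) / 4.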
1214 template <typename ReshapeOp>
1215 class ReassociatingReshapeOpConversion
1216     : public ConvertOpToLLVMPattern<ReshapeOp> {
1217 public:
1218   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1219   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1220 
1221   LogicalResult
1222   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1223                   ConversionPatternRewriter &rewriter) const override {
1224     MemRefType dstType = reshapeOp.getResultType();
1225     MemRefType srcType = reshapeOp.getSrcType();
1226 
1227     // The condition on the layouts can be ignored when all shapes are static.
1228     if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
1229       if (!srcType.getLayout().isIdentity() ||
1230           !dstType.getLayout().isIdentity()) {
1231         return rewriter.notifyMatchFailure(
            reshapeOp, "only identity layout map is supported");
1233       }
1234     }
1235 
1236     int64_t offset;
1237     SmallVector<int64_t, 4> strides;
1238     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1239       return rewriter.notifyMatchFailure(
1240           reshapeOp, "failed to get stride and offset exprs");
1241     }
1242 
1243     MemRefDescriptor srcDesc(adaptor.src());
1244     Location loc = reshapeOp->getLoc();
1245     auto dstDesc = MemRefDescriptor::undef(
1246         rewriter, loc, this->typeConverter->convertType(dstType));
1247     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1248     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1249     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1250 
1251     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1252     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1253     Type llvmIndexType =
1254         this->typeConverter->convertType(rewriter.getIndexType());
1255     SmallVector<Value> dstShape = getDynamicOutputShape(
1256         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1257         srcStaticShape, srcDesc, dstStaticShape);
1258     for (auto &en : llvm::enumerate(dstShape))
1259       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1260 
1261     auto isStaticStride = [](int64_t stride) {
1262       return !ShapedType::isDynamicStrideOrOffset(stride);
1263     };
1264     if (llvm::all_of(strides, isStaticStride)) {
1265       for (auto &en : llvm::enumerate(strides))
1266         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1267     } else {
1268       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1269                                                    rewriter.getIndexAttr(1));
1270       Value stride = c1;
1271       for (auto dimIndex :
1272            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1273         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1274         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1275       }
1276     }
1277     rewriter.replaceOp(reshapeOp, {dstDesc});
1278     return success();
1279   }
1280 };
1281 
1282 /// Conversion pattern that transforms a subview op into:
1283 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1284 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1285 ///      and stride.
1286 /// The subview op is replaced by the descriptor.
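///
/// For example (a schematic sketch, not copied from an actual test):
///   %1 = memref.subview %0[%i, 0] [4, 4] [1, 1]
///       : memref<8x8xf32> to memref<4x4xf32, #strided>
/// produces a descriptor that reuses the allocated and aligned pointers of
/// %0 (bitcast to the result element type), with
///   offset  = offset(%0) + %i * stride(%0, 0) + 0 * stride(%0, 1)
///   sizes   = [4, 4]
///   strides = taken from the inferred strided result type (here [8, 1])
/// where `#strided` stands for the strided layout inferred for the result.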
1287 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1288   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1289 
1290   LogicalResult
1291   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1292                   ConversionPatternRewriter &rewriter) const override {
1293     auto loc = subViewOp.getLoc();
1294 
1295     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1296     auto sourceElementTy =
1297         typeConverter->convertType(sourceMemRefType.getElementType());
1298 
1299     auto viewMemRefType = subViewOp.getType();
1300     auto inferredType = memref::SubViewOp::inferResultType(
1301                             subViewOp.getSourceType(),
1302                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1303                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1304                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1305                             .cast<MemRefType>();
1306     auto targetElementTy =
1307         typeConverter->convertType(viewMemRefType.getElementType());
1308     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1309     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1310         !LLVM::isCompatibleType(sourceElementTy) ||
1311         !LLVM::isCompatibleType(targetElementTy) ||
1312         !LLVM::isCompatibleType(targetDescTy))
1313       return failure();
1314 
1315     // Extract the offset and strides from the type.
1316     int64_t offset;
1317     SmallVector<int64_t, 4> strides;
1318     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1319     if (failed(successStrides))
1320       return failure();
1321 
1322     // Create the descriptor.
1323     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1324       return failure();
1325     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1326     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1327 
1328     // Copy the buffer pointer from the old descriptor to the new one.
1329     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1330     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1331         loc,
1332         LLVM::LLVMPointerType::get(targetElementTy,
1333                                    viewMemRefType.getMemorySpaceAsInt()),
1334         extracted);
1335     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1336 
1337     // Copy the aligned pointer from the old descriptor to the new one.
1338     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1339     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1340         loc,
1341         LLVM::LLVMPointerType::get(targetElementTy,
1342                                    viewMemRefType.getMemorySpaceAsInt()),
1343         extracted);
1344     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1345 
1346     size_t inferredShapeRank = inferredType.getRank();
1347     size_t resultShapeRank = viewMemRefType.getRank();
1348 
1349     // Extract strides needed to compute offset.
1350     SmallVector<Value, 4> strideValues;
1351     strideValues.reserve(inferredShapeRank);
1352     for (unsigned i = 0; i < inferredShapeRank; ++i)
1353       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1354 
1355     // Offset.
1356     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1357     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1358       targetMemRef.setConstantOffset(rewriter, loc, offset);
1359     } else {
1360       Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because trailing offsets may be omitted. Omitted offsets are
      // guaranteed to be interpreted as 0, so we can simply skip the extra
      // dimensions.
1364       for (unsigned i = 0, e = std::min(inferredShapeRank,
1365                                         subViewOp.getMixedOffsets().size());
1366            i < e; ++i) {
1367         Value offset =
1368             // TODO: need OpFoldResult ODS adaptor to clean this up.
1369             subViewOp.isDynamicOffset(i)
1370                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1371                 : rewriter.create<LLVM::ConstantOp>(
1372                       loc, llvmIndexType,
1373                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1374         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1375         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1376       }
1377       targetMemRef.setOffset(rewriter, loc, baseOffset);
1378     }
1379 
1380     // Update sizes and strides.
1381     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1382     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1383     assert(mixedSizes.size() == mixedStrides.size() &&
1384            "expected sizes and strides of equal length");
1385     llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
1386     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1387          i >= 0 && j >= 0; --i) {
1388       if (unusedDims.contains(i))
1389         continue;
1390 
      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In that case, the size is guaranteed
      // to be the full source dimension size and the stride to be 1.
1394       Value size, stride;
1395       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1396         // If the static size is available, use it directly. This is similar to
1397         // the folding of dim(constant-op) but removes the need for dim to be
1398         // aware of LLVM constants and for this pass to be aware of std
1399         // constants.
1400         int64_t staticSize =
1401             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1402         if (staticSize != ShapedType::kDynamicSize) {
1403           size = rewriter.create<LLVM::ConstantOp>(
1404               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1405         } else {
1406           Value pos = rewriter.create<LLVM::ConstantOp>(
1407               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1408           Value dim =
1409               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1410           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1411               loc, llvmIndexType, dim);
1412           size = cast.getResult(0);
1413         }
1414         stride = rewriter.create<LLVM::ConstantOp>(
1415             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1416       } else {
1417         // TODO: need OpFoldResult ODS adaptor to clean this up.
1418         size =
1419             subViewOp.isDynamicSize(i)
1420                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1421                 : rewriter.create<LLVM::ConstantOp>(
1422                       loc, llvmIndexType,
1423                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1424         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1425           stride = rewriter.create<LLVM::ConstantOp>(
1426               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1427         } else {
1428           stride =
1429               subViewOp.isDynamicStride(i)
1430                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1431                   : rewriter.create<LLVM::ConstantOp>(
1432                         loc, llvmIndexType,
1433                         rewriter.getI64IntegerAttr(
1434                             subViewOp.getStaticStride(i)));
1435           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1436         }
1437       }
1438       targetMemRef.setSize(rewriter, loc, j, size);
1439       targetMemRef.setStride(rewriter, loc, j, stride);
1440       j--;
1441     }
1442 
1443     rewriter.replaceOp(subViewOp, {targetMemRef});
1444     return success();
1445   }
1446 };
1447 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      sizes and strides of the source memref permuted according to the
///      transposition.
/// The transpose op is replaced by the descriptor.
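///
/// For example (schematic):
///   %1 = memref.transpose %0 (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, #transposed>
/// shares %0's pointers and offset and sets
///   size(0) = size(%0, 1), stride(0) = stride(%0, 1)
///   size(1) = size(%0, 0), stride(1) = stride(%0, 0)
/// where `#transposed` stands for the permuted strided layout.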
1455 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1456 public:
1457   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1458 
1459   LogicalResult
1460   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1461                   ConversionPatternRewriter &rewriter) const override {
1462     auto loc = transposeOp.getLoc();
1463     MemRefDescriptor viewMemRef(adaptor.in());
1464 
1465     // No permutation, early exit.
1466     if (transposeOp.permutation().isIdentity())
1467       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1468 
1469     auto targetMemRef = MemRefDescriptor::undef(
1470         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1471 
1472     // Copy the base and aligned pointers from the old descriptor to the new
1473     // one.
1474     targetMemRef.setAllocatedPtr(rewriter, loc,
1475                                  viewMemRef.allocatedPtr(rewriter, loc));
1476     targetMemRef.setAlignedPtr(rewriter, loc,
1477                                viewMemRef.alignedPtr(rewriter, loc));
1478 
1479     // Copy the offset pointer from the old descriptor to the new one.
1480     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1481 
1482     // Iterate over the dimensions and apply size/stride permutation.
1483     for (const auto &en :
1484          llvm::enumerate(transposeOp.permutation().getResults())) {
1485       int sourcePos = en.index();
1486       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1487       targetMemRef.setSize(rewriter, loc, targetPos,
1488                            viewMemRef.size(rewriter, loc, sourcePos));
1489       targetMemRef.setStride(rewriter, loc, targetPos,
1490                              viewMemRef.stride(rewriter, loc, sourcePos));
1491     }
1492 
1493     rewriter.replaceOp(transposeOp, {targetMemRef});
1494     return success();
1495   }
1496 };
1497 
/// Conversion pattern that transforms a view op into:
1499 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1500 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1501 ///      and stride.
1502 /// The view op is replaced by the descriptor.
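///
/// For example (schematic):
///   %1 = memref.view %0[%shift][%size]
///       : memref<2048xi8> to memref<?x128xf32>
/// produces a descriptor whose aligned pointer is the aligned pointer of %0
/// advanced by %shift bytes and bitcast to the f32 element type, with
/// offset 0, sizes [%size, 128], and strides [128, 1].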
1503 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1504   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1505 
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by picking the matching entry
  // of `dynamicSizes` (found by counting the dynamic dimensions before idx).
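  // For example (illustrative): with shape [?, 4, ?] and dynamicSizes
  // (%a, %b), idx 1 yields the constant 4 and idx 2 yields %b, since exactly
  // one dynamic dimension (dim 0) precedes it.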
1508   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1509                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1510                 unsigned idx) const {
1511     assert(idx < shape.size());
1512     if (!ShapedType::isDynamic(shape[idx]))
1513       return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
1515     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1516       return ShapedType::isDynamic(v);
1517     });
1518     return dynamicSizes[nDynamic];
1519   }
1520 
1521   // Build and return the idx^th stride, either by returning the constant stride
1522   // or by computing the dynamic stride from the current `runningStride` and
1523   // `nextSize`. The caller should keep a running stride and update it with the
1524   // result returned by this function.
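  // For example (illustrative): for a memref<?x?x4xf32> with identity layout,
  // walking the dimensions from innermost to outermost yields strides 1, 4,
  // and 4 * size(dim 1), with `runningStride` and `nextSize` carrying the
  // partial products between calls.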
1525   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1526                   ArrayRef<int64_t> strides, Value nextSize,
1527                   Value runningStride, unsigned idx) const {
1528     assert(idx < strides.size());
1529     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1530       return createIndexConstant(rewriter, loc, strides[idx]);
1531     if (nextSize)
1532       return runningStride
1533                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1534                  : nextSize;
1535     assert(!runningStride);
1536     return createIndexConstant(rewriter, loc, 1);
1537   }
1538 
1539   LogicalResult
1540   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1541                   ConversionPatternRewriter &rewriter) const override {
1542     auto loc = viewOp.getLoc();
1543 
1544     auto viewMemRefType = viewOp.getType();
1545     auto targetElementTy =
1546         typeConverter->convertType(viewMemRefType.getElementType());
1547     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1548     if (!targetDescTy || !targetElementTy ||
1549         !LLVM::isCompatibleType(targetElementTy) ||
1550         !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
1552              failure();
1553 
1554     int64_t offset;
1555     SmallVector<int64_t, 4> strides;
1556     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1557     if (failed(successStrides))
1558       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1559     assert(offset == 0 && "expected offset to be 0");
1560 
1561     // Create the descriptor.
1562     MemRefDescriptor sourceMemRef(adaptor.source());
1563     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1564 
1565     // Field 1: Copy the allocated pointer, used for malloc/free.
1566     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1567     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1568     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1569         loc,
1570         LLVM::LLVMPointerType::get(targetElementTy,
1571                                    srcMemRefType.getMemorySpaceAsInt()),
1572         allocatedPtr);
1573     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1574 
1575     // Field 2: Copy the actual aligned pointer to payload.
1576     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1577     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1578                                               alignedPtr, adaptor.byte_shift());
1579     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1580         loc,
1581         LLVM::LLVMPointerType::get(targetElementTy,
1582                                    srcMemRefType.getMemorySpaceAsInt()),
1583         alignedPtr);
1584     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1585 
1586     // Field 3: The offset in the resulting type must be 0. This is because of
1587     // the type change: an offset on srcType* may not be expressible as an
1588     // offset on dstType*.
1589     targetMemRef.setOffset(rewriter, loc,
1590                            createIndexConstant(rewriter, loc, offset));
1591 
1592     // Early exit for 0-D corner case.
1593     if (viewMemRefType.getRank() == 0)
1594       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1595 
1596     // Fields 4 and 5: Update sizes and strides.
1597     if (strides.back() != 1)
1598       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1599              failure();
1600     Value stride = nullptr, nextSize = nullptr;
1601     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1602       // Update size.
1603       Value size =
1604           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1605       targetMemRef.setSize(rewriter, loc, i, size);
1606       // Update stride.
1607       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1608       targetMemRef.setStride(rewriter, loc, i, stride);
1609       nextSize = size;
1610     }
1611 
1612     rewriter.replaceOp(viewOp, {targetMemRef});
1613     return success();
1614   }
1615 };
1616 
1617 //===----------------------------------------------------------------------===//
1618 // AtomicRMWOpLowering
1619 //===----------------------------------------------------------------------===//
1620 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
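///
/// For example (schematic), an op with kind `addf`
///   %res = memref.atomic_rmw addf %val, %buf[%i] : (f32, memref<10xf32>) -> f32
/// maps to an `llvm.atomicrmw fadd` on the strided element pointer, while
/// kinds without a direct llvm.atomicrmw counterpart (e.g. maxf) return
/// llvm::None and are left to the llvm.cmpxchg-based fallback.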
1623 static Optional<LLVM::AtomicBinOp>
1624 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1625   switch (atomicOp.kind()) {
1626   case arith::AtomicRMWKind::addf:
1627     return LLVM::AtomicBinOp::fadd;
1628   case arith::AtomicRMWKind::addi:
1629     return LLVM::AtomicBinOp::add;
1630   case arith::AtomicRMWKind::assign:
1631     return LLVM::AtomicBinOp::xchg;
1632   case arith::AtomicRMWKind::maxs:
1633     return LLVM::AtomicBinOp::max;
1634   case arith::AtomicRMWKind::maxu:
1635     return LLVM::AtomicBinOp::umax;
1636   case arith::AtomicRMWKind::mins:
1637     return LLVM::AtomicBinOp::min;
1638   case arith::AtomicRMWKind::minu:
1639     return LLVM::AtomicBinOp::umin;
1640   case arith::AtomicRMWKind::ori:
1641     return LLVM::AtomicBinOp::_or;
1642   case arith::AtomicRMWKind::andi:
1643     return LLVM::AtomicBinOp::_and;
1644   default:
1645     return llvm::None;
1646   }
1648 }
1649 
1650 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1651   using Base::Base;
1652 
1653   LogicalResult
1654   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1655                   ConversionPatternRewriter &rewriter) const override {
1656     if (failed(match(atomicOp)))
1657       return failure();
1658     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1659     if (!maybeKind)
1660       return failure();
1661     auto resultType = adaptor.value().getType();
1662     auto memRefType = atomicOp.getMemRefType();
1663     auto dataPtr =
1664         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1665                              adaptor.indices(), rewriter);
1666     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1667         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1668         LLVM::AtomicOrdering::acq_rel);
1669     return success();
1670   }
1671 };
1672 
1673 } // namespace
1674 
1675 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1676                                                   RewritePatternSet &patterns) {
1677   // clang-format off
1678   patterns.add<
1679       AllocaOpLowering,
1680       AllocaScopeOpLowering,
1681       AtomicRMWOpLowering,
1682       AssumeAlignmentOpLowering,
1683       DimOpLowering,
1684       GlobalMemrefOpLowering,
1685       GetGlobalMemrefOpLowering,
1686       LoadOpLowering,
1687       MemRefCastOpLowering,
1688       MemRefCopyOpLowering,
1689       MemRefReinterpretCastOpLowering,
1690       MemRefReshapeOpLowering,
1691       PrefetchOpLowering,
1692       RankOpLowering,
1693       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1694       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1695       StoreOpLowering,
1696       SubViewOpLowering,
1697       TransposeOpLowering,
1698       ViewOpLowering>(converter);
1699   // clang-format on
1700   auto allocLowering = converter.getOptions().allocLowering;
1701   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1702     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1703   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1704     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1705 }
1706 
1707 namespace {
1708 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1709   MemRefToLLVMPass() = default;
1710 
1711   void runOnOperation() override {
1712     Operation *op = getOperation();
1713     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1714     LowerToLLVMOptions options(&getContext(),
1715                                dataLayoutAnalysis.getAtOrAbove(op));
1716     options.allocLowering =
1717         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1718                          : LowerToLLVMOptions::AllocLowering::Malloc);
1719     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1720       options.overrideIndexBitwidth(indexBitwidth);
1721 
1722     LLVMTypeConverter typeConverter(&getContext(), options,
1723                                     &dataLayoutAnalysis);
1724     RewritePatternSet patterns(&getContext());
1725     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1726     LLVMConversionTarget target(getContext());
1727     target.addLegalOp<FuncOp>();
1728     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1729       signalPassFailure();
1730   }
1731 };
1732 } // namespace
1733 
1734 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1735   return std::make_unique<MemRefToLLVMPass>();
1736 }
1737