//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
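      // Over-allocating by `alignment` extra bytes is the standard
      // over-allocate-then-round-up idiom: it guarantees that rounding the
      // returned address up to the requested alignment (done below) still
      // leaves the full payload inside the buffer.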
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
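      // `createAligned` rounds its input up to the next multiple of
      // `alignment`. As a worked example (assuming a 16-byte alignment): an
      // allocated address of 0x1004 becomes 0x1004 + 15 = 0x1013, and
      // 0x1013 - (0x1013 % 16) = 0x1010, the next 16-byte boundary.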
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active at
  /// `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
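    // For example (assuming f32 occupies 4 bytes under the active layout):
    // memref<4x32xf32> gives sizeDivisor = 4 * 4 * 32 = 512, so the size is a
    // multiple of 16 and aligned_alloc needs no size padding. Dynamic
    // dimensions are skipped above, making the result a conservative divisor.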
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we bump the element size to the next power of two if it
    // is not one already.
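    // For example (assuming a 4-byte f32 element): PowerOf2Ceil(4) = 4, which
    // the max() below bumps to the 16-byte minimum; a 24-byte struct element
    // would instead round up to a 32-byte alignment.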
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
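    // e.g. (illustrative arithmetic): a 100-byte buffer with a 16-byte
    // alignment is padded to createAligned(100, 16) = 112 bytes so that the
    // aligned_alloc contract holds.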
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer with an `llvm.alloca`. The allocated and
  /// aligned pointers coincide: alignment is handled by the alloca itself, so
  /// no post-allocation adjustment (as with malloc) is needed.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Stack allocations: with alloca, one gets a pointer to the element type
    // right away.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();
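    // The rewritten CFG has roughly this shape (illustrative; block names are
    // made up):
    //   ^current:  %sp = llvm.intr.stacksave ; llvm.br ^body
    //   ^body:     ... ; llvm.intr.stackrestore %sp ; llvm.br ^continue
    //   ^continue: (block arguments carry the scope results, if any)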

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(remainingOpsBlock,
                                           allocaScopeOp.getResultTypes());
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::AssumeAlignmentOp::Adaptor transformed(operands);
    Value memref = transformed.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
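    // For instance, with alignment = 16 the emitted sequence is roughly
    // (illustrative; the integer width follows the pointer size):
    //   %p = llvm.ptrtoint %alignedPtr : !llvm.ptr<...> to i64
    //   %a = llvm.and %p, %mask        // mask = 15
    //   %c = llvm.icmp "eq" %a, %zero
    //   "llvm.intr.assume"(%c)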
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    memref::DeallocOp::Adaptor transformed(operands);

    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(transformed.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
                                    operandType, dimOp, operands, rewriter)});
      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
                                    operandType, dimOp, operands, rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(transformed.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);
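    // A ranked descriptor is laid out as (illustrative):
    //   { T* allocated, T* aligned, index offset,
    //     index sizes[rank], index strides[rank] }
    // Field 2 below is therefore the offset, and the sizes array begins at
    // the very next index slot, which the GEPs below exploit.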

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.value().cast<IntegerAttr>().getValue().getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);
    // Take advantage of a constant index when available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(transformed.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = transformed.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(transformed.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from memref
  // to the LLVM dialect, so the LLVM type needs to be a multi-dimensional
  // array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
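  // e.g. memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>> (a sketch;
  // the element type goes through convertType first).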
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type().cast<MemRefType>();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getValue({});
    }

    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into it. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op consists of getting the
  /// address of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
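    // e.g. for a rank-2 global of type !llvm.array<4 x array<8 x f32>>, the
    // GEP indices are (0, 0, 0) and the result is an f32* pointing at
    // element [0][0].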
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value, 4> operands = {addressOf};
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known-bad value to help
    // debugging if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // The aligned pointer is the address of the first element; the allocated
    // pointer carries the sentinel above since no dealloc is expected.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::LoadOp::Adaptor transformed(operands);
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
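    // getStridedElementPtr emits the usual linearization, roughly:
    //   dataPtr = alignedPtr + offset + sum_d(indices[d] * strides[d])
    // with constant parts folded when the layout is statically known.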
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();
    memref::StoreOp::Adaptor transformed(operands);

    Value dataPtr =
        getStridedElementPtr(op.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                               dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::PrefetchOp::Adaptor transformed(operands);
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
                                         transformed.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands has an unranked memref type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    memref::CastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {transformed.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      //   Set the rank in the destination from the memref type.
      //   Allocate space on the stack and copy the src memref descriptor.
      //   Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    memref::CopyOp::Adaptor adaptor(operands);
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
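    // Note: the callee is a runtime helper (provided, e.g., by MLIR's
    // CRunnerUtils library as `memrefCopy`) that copies element by element
    // between the two unranked descriptors; the lookup above only emits its
    // declaration.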
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::ReinterpretCastOp::Adaptor adaptor(operands,
                                               castOp->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create the descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set the offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = reshapeOp.getOperation();
    memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref.reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract the address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
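    // Pseudo-code of the emitted loop (strides are computed innermost-first,
    // for a row-major contiguous result):
    //   stride = 1;
    //   for (i = rank - 1; i >= 0; --i) {
    //     sizes[i] = shape[i];
    //     strides[i] = stride;
    //     stride *= shape[i];
    //   }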
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType = memref::SubViewOp::inferResultType(
                            subViewOp.getSourceType(),
                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
                            extractFromI64ArrayAttr(subViewOp.static_strides()))
                            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(operands.front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(operands.front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    auto shape = viewMemRefType.getShape();
    auto inferredShape = inferredType.getShape();
    size_t inferredShapeRank = inferredShape.size();
    size_t resultShapeRank = shape.size();
    llvm::SmallDenseSet<unsigned> unusedDims =
        computeRankReductionMask(inferredShape, shape).getValue();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
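      // The loop below materializes the usual strided-layout formula
      // (sketched): baseOffset += offsets[i] * strides[i] for each leading
      // dimension i.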
1089       for (unsigned i = 0, e = std::min(inferredShapeRank,
1090                                         subViewOp.getMixedOffsets().size());
1091            i < e; ++i) {
1092         Value offset =
1093             // TODO: need OpFoldResult ODS adaptor to clean this up.
1094             subViewOp.isDynamicOffset(i)
1095                 ? operands[subViewOp.getIndexOfDynamicOffset(i)]
1096                 : rewriter.create<LLVM::ConstantOp>(
1097                       loc, llvmIndexType,
1098                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1099         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1100         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1101       }
1102       targetMemRef.setOffset(rewriter, loc, baseOffset);
1103     }
1104 
1105     // Update sizes and strides.
1106     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1107     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1108     assert(mixedSizes.size() == mixedStrides.size() &&
1109            "expected sizes and strides of equal length");
1110     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1111          i >= 0 && j >= 0; --i) {
1112       if (unusedDims.contains(i))
1113         continue;
1114 
1115       // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
1116       // In this case, the size is guaranteed to be interpreted as Dim and the
1117       // stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          size = rewriter.create<LLVM::DialectCastOp>(
              loc, llvmIndexType,
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos));
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? operands[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride = subViewOp.isDynamicStride(i)
                       ? operands[subViewOp.getIndexOfDynamicStride(i)]
                       : rewriter.create<LLVM::ConstantOp>(
                             loc, llvmIndexType,
                             rewriter.getI64IntegerAttr(
                                 subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. A copy of the allocated and aligned pointers and of the offset from
///      the source descriptor into the new one.
///   3. Updates to the new descriptor's sizes and strides, permuted according
///      to the transpose permutation map.
/// The transpose op is replaced by the new descriptor.
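/// For example (illustrative IR; the result layout map is elided):
///   %1 = memref.transpose %0 (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, #map>
/// copies the pointers and offset unchanged, writes size/stride 0 of the
/// source descriptor into position 1 of the result, and vice versa.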
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    memref::TransposeOpAdaptor adaptor(operands);
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
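/// For example (illustrative IR with hypothetical operands):
///   %1 = memref.view %src[%byte_shift][%size]
///       : memref<2048xi8> to memref<?x4xf32>
/// yields a descriptor whose aligned pointer is the source pointer advanced by
/// %byte_shift bytes and bitcast to the target element pointer type, with a
/// zero offset and contiguous strides.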
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by picking the matching dynamic
  // size operand (dynamic dimensions are matched to operands positionally).
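  // For example (hypothetical values): with shape = [4, ?, 8, ?] and idx = 3,
  // exactly one dynamic dimension precedes index 3, so the result is
  // dynamicSizes[1].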
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
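  // For example (hypothetical type): for a contiguous memref<?x?xf32>,
  // iterating from the innermost dimension first returns the constant stride 1
  // and then runningStride * nextSize, i.e. the innermost size.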
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();
    memref::ViewOpAdaptor adaptor(operands);

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
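    // The view source is a rank-1 i8 memref, so this GEP advances the aligned
    // pointer by `byte_shift` bytes.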
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  // DeallocOpLowering is registered together with the chosen alloc lowering
  // below, so memref.dealloc is only lowered when allocation is.
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<LLVM::DialectCastOp, FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

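/// Factory for the pass; with the tablegen-generated registration from
/// ConvertMemRefToLLVMBase this is typically invoked via an `mlir-opt` flag
/// (assumed to be -convert-memref-to-llvm).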
std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}