1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "llvm/ADT/SmallBitVector.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 
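/// Lowers memref.alloc to a heap allocation through `malloc`. As an
/// illustrative sketch only (the exact IR depends on the index bitwidth and
/// data layout), an allocation such as
///
///   %0 = memref.alloc() {alignment = 32 : i64} : memref<32xf32>
///
/// turns into a `malloc` call whose size is padded by the requested alignment,
/// followed by pointer arithmetic that rounds the returned address up to the
/// next multiple of 32 to produce the descriptor's aligned pointer; the
/// original pointer is kept in the descriptor for the eventual `free`.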
27 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
28   AllocOpLowering(LLVMTypeConverter &converter)
29       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
30                                 converter) {}
31 
32   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
33                                           Location loc, Value sizeBytes,
34                                           Operation *op) const override {
35     // Heap allocations.
36     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
37     MemRefType memRefType = allocOp.getType();
38 
39     Value alignment;
40     if (auto alignmentAttr = allocOp.alignment()) {
41       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
42     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
43       // In the case where no alignment is specified, we may want to override
44       // `malloc's` behavior. `malloc` typically aligns at the size of the
45       // biggest scalar on a target HW. For non-scalars, use the natural
46       // alignment of the LLVM type given by the LLVM DataLayout.
47       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
48     }
49 
50     if (alignment) {
51       // Adjust the allocation size to consider alignment.
52       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
53     }
54 
55     // Allocate the underlying buffer and store a pointer to it in the MemRef
56     // descriptor.
57     Type elementPtrType = this->getElementPtrType(memRefType);
58     auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
59         allocOp->getParentOfType<ModuleOp>(), getIndexType());
60     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
61                                   getVoidPtrType());
62     Value allocatedPtr =
63         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
64 
65     Value alignedPtr = allocatedPtr;
66     if (alignment) {
67       // Compute the aligned type pointer.
68       Value allocatedInt =
69           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
70       Value alignmentInt =
71           createAligned(rewriter, loc, allocatedInt, alignment);
72       alignedPtr =
73           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
74     }
75 
76     return std::make_tuple(allocatedPtr, alignedPtr);
77   }
78 };
79 
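/// Lowers memref.alloc to a call to `aligned_alloc` instead of `malloc`. A
/// hedged sketch of the size handling: since `aligned_alloc` requires the size
/// to be a multiple of the alignment, e.g. a `memref<5xf32>` allocated with
/// alignment 16 (20 bytes of data) has its allocation size rounded up to 32
/// bytes, while a `memref<4xf32>` (16 bytes) is passed through unchanged.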
80 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
81   AlignedAllocOpLowering(LLVMTypeConverter &converter)
82       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
83                                 converter) {}
84 
85   /// Returns the memref's element size in bytes using the data layout active at
86   /// `op`.
87   // TODO: there are other places where this is used. Expose publicly?
88   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
89     const DataLayout *layout = &defaultLayout;
90     if (const DataLayoutAnalysis *analysis =
91             getTypeConverter()->getDataLayoutAnalysis()) {
92       layout = &analysis->getAbove(op);
93     }
94     Type elementType = memRefType.getElementType();
95     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
96       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
97                                                          *layout);
98     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
99       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
100           memRefElementType, *layout);
101     return layout->getTypeSize(elementType);
102   }
103 
104   /// Returns true if the memref size in bytes is known to be a multiple of
105   /// `factor`, assuming the data layout active at `op`.
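  /// For example, for `memref<4x?x8xf32>` the accumulated divisor is
  /// 4 (bytes per f32) * 4 * 8 = 128, so the total size is known to be a
  /// multiple of any factor of 128 regardless of the dynamic dimension.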
106   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
107                               Operation *op) const {
108     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
109     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
110       if (ShapedType::isDynamic(type.getDimSize(i)))
111         continue;
112       sizeDivisor = sizeDivisor * type.getDimSize(i);
113     }
114     return sizeDivisor % factor == 0;
115   }
116 
117   /// Returns the alignment to be used for the allocation call itself.
118   /// aligned_alloc requires the alignment to be a power of two and the
119   /// allocation size to be a multiple of that alignment.
120   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
121     if (Optional<uint64_t> alignment = allocOp.alignment())
122       return *alignment;
123 
124     // Whenever we don't have alignment set, we will use an alignment
125     // consistent with the element type; since the alignment has to be a
126     // power of two, we bump to the next power of two if it isn't one already.
127     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
128     return std::max(kMinAlignedAllocAlignment,
129                     llvm::PowerOf2Ceil(eltSizeBytes));
130   }
131 
132   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
133                                           Location loc, Value sizeBytes,
134                                           Operation *op) const override {
135     // Heap allocations.
136     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
137     MemRefType memRefType = allocOp.getType();
138     int64_t alignment = getAllocationAlignment(allocOp);
139     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
140 
141     // aligned_alloc requires size to be a multiple of alignment; we will pad
142     // the size to the next multiple if necessary.
143     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
144       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
145 
146     Type elementPtrType = this->getElementPtrType(memRefType);
147     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
148         allocOp->getParentOfType<ModuleOp>(), getIndexType());
149     auto results =
150         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
151                        getVoidPtrType());
152     Value allocatedPtr =
153         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
154 
155     return std::make_tuple(allocatedPtr, allocatedPtr);
156   }
157 
158   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
159   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
160 
161   /// Default layout to use in absence of the corresponding analysis.
162   DataLayout defaultLayout;
163 };
164 
165 // Out-of-line definition, required until C++17.
166 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
167 
168 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
169   AllocaOpLowering(LLVMTypeConverter &converter)
170       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
171                                 converter) {}
172 
173   /// Allocates the underlying buffer on the stack with an `llvm.alloca`. Since
174   /// the alloca can honor the requested alignment directly, the allocated and
175   /// aligned pointers are one and the same.
176   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
177                                           Location loc, Value sizeBytes,
178                                           Operation *op) const override {
179 
180     // With alloca, one gets a pointer to the element type right away; this is
181     // the case for stack allocations.
182     auto allocaOp = cast<memref::AllocaOp>(op);
183     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
184 
185     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
186         loc, elementPtrType, sizeBytes,
187         allocaOp.alignment() ? *allocaOp.alignment() : 0);
188 
189     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
190   }
191 };
192 
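/// Lowers memref.alloca_scope by inlining its body and bracketing it with
/// stack save/restore intrinsics. Roughly (an illustrative sketch), the scope
///
///   memref.alloca_scope {
///     ... allocas and other body ops ...
///     memref.alloca_scope.return
///   }
///
/// becomes a stack-save in the predecessor block, a branch into the inlined
/// body, and a stack-restore followed by a branch to the continuation block in
/// place of the alloca_scope return.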
193 struct AllocaScopeOpLowering
194     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
195   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
196 
197   LogicalResult
198   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
199                   ConversionPatternRewriter &rewriter) const override {
200     OpBuilder::InsertionGuard guard(rewriter);
201     Location loc = allocaScopeOp.getLoc();
202 
203     // Split the current block before the AllocaScopeOp to create the inlining
204     // point.
205     auto *currentBlock = rewriter.getInsertionBlock();
206     auto *remainingOpsBlock =
207         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
208     Block *continueBlock;
209     if (allocaScopeOp.getNumResults() == 0) {
210       continueBlock = remainingOpsBlock;
211     } else {
212       continueBlock = rewriter.createBlock(
213           remainingOpsBlock, allocaScopeOp.getResultTypes(),
214           SmallVector<Location>(allocaScopeOp->getNumResults(),
215                                 allocaScopeOp.getLoc()));
216       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
217     }
218 
219     // Inline body region.
220     Block *beforeBody = &allocaScopeOp.bodyRegion().front();
221     Block *afterBody = &allocaScopeOp.bodyRegion().back();
222     rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
223 
224     // Save stack and then branch into the body of the region.
225     rewriter.setInsertionPointToEnd(currentBlock);
226     auto stackSaveOp =
227         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
228     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
229 
230     // Replace the alloca_scope return with a branch that jumps out of the body.
231     // Stack restore before leaving the body region.
232     rewriter.setInsertionPointToEnd(afterBody);
233     auto returnOp =
234         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
235     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
236         returnOp, returnOp.results(), continueBlock);
237 
238     // Insert stack restore before jumping out the body of the region.
239     rewriter.setInsertionPoint(branchOp);
240     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
241 
242     // Replace the op with the values returned from the body region.
243     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
244 
245     return success();
246   }
247 };
248 
249 struct AssumeAlignmentOpLowering
250     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
251   using ConvertOpToLLVMPattern<
252       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
253 
254   LogicalResult
255   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
256                   ConversionPatternRewriter &rewriter) const override {
257     Value memref = adaptor.memref();
258     unsigned alignment = op.alignment();
259     auto loc = op.getLoc();
260 
261     MemRefDescriptor memRefDescriptor(memref);
262     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
263 
264     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
265     // the asserted memref.alignedPtr isn't used anywhere else, as the real
266     // users like load/store/views always re-extract memref.alignedPtr as they
267     // get lowered.
268     //
269     // This relies on LLVM's CSE optimization (potentially after SROA), since
270     // after CSE all memref.alignedPtr instances get de-duplicated into the same
271     // pointer SSA value.
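    //
    // A sketch of the emitted IR (names and exact syntax illustrative):
    //   %p   = llvm.ptrtoint %alignedPtr : !llvm.ptr<...> to iN
    //   %and = llvm.and %p, %c_alignment_minus_1
    //   %cmp = llvm.icmp "eq" %and, %c0
    //   "llvm.intr.assume"(%cmp) : (i1) -> ()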
272     auto intPtrType =
273         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
274     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
275     Value mask =
276         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
277     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
278     rewriter.create<LLVM::AssumeOp>(
279         loc, rewriter.create<LLVM::ICmpOp>(
280                  loc, LLVM::ICmpPredicate::eq,
281                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
282 
283     rewriter.eraseOp(op);
284     return success();
285   }
286 };
287 
288 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
289 // The memref descriptor being an SSA value, there is no need to clean it up
290 // in any way.
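//
// For example (an illustrative sketch; the pointer types depend on the element
// type and address space), `memref.dealloc %m : memref<4xf32>` becomes roughly:
//   %alloc = llvm.extractvalue %desc[0] : !llvm.struct<(ptr<f32>, ...)>
//   %void  = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()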
291 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
292   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
293 
294   explicit DeallocOpLowering(LLVMTypeConverter &converter)
295       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
296 
297   LogicalResult
298   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
299                   ConversionPatternRewriter &rewriter) const override {
300     // Insert the `free` declaration if it is not already present.
301     auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
302     MemRefDescriptor memref(adaptor.memref());
303     Value casted = rewriter.create<LLVM::BitcastOp>(
304         op.getLoc(), getVoidPtrType(),
305         memref.allocatedPtr(rewriter, op.getLoc()));
306     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
307         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
308     return success();
309   }
310 };
311 
312 // A `dim` is converted to a constant for static sizes and to an access to the
313 // size stored in the memref descriptor for dynamic sizes.
314 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
315   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
316 
317   LogicalResult
318   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
319                   ConversionPatternRewriter &rewriter) const override {
320     Type operandType = dimOp.source().getType();
321     if (operandType.isa<UnrankedMemRefType>()) {
322       rewriter.replaceOp(
323           dimOp, {extractSizeOfUnrankedMemRef(
324                      operandType, dimOp, adaptor.getOperands(), rewriter)});
325 
326       return success();
327     }
328     if (operandType.isa<MemRefType>()) {
329       rewriter.replaceOp(
330           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
331                                             adaptor.getOperands(), rewriter)});
332       return success();
333     }
334     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
335   }
336 
337 private:
338   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
339                                     OpAdaptor adaptor,
340                                     ConversionPatternRewriter &rewriter) const {
341     Location loc = dimOp.getLoc();
342 
343     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
344     auto scalarMemRefType =
345         MemRefType::get({}, unrankedMemRefType.getElementType());
346     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
347 
348     // Extract pointer to the underlying ranked descriptor and bitcast it to a
349     // memref<element_type> descriptor pointer to minimize the number of GEP
350     // operations.
351     UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
352     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
353     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
354         loc,
355         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
356                                    addressSpace),
357         underlyingRankedDesc);
358 
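    // A ranked memref descriptor is laid out as
    //   { allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank] },
    // so struct index 2 addresses the offset field and the size of dimension
    // `d` lives `1 + d` index elements past it.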
359     // Get pointer to offset field of memref<element_type> descriptor.
360     Type indexPtrTy = LLVM::LLVMPointerType::get(
361         getTypeConverter()->getIndexType(), addressSpace);
362     Value two = rewriter.create<LLVM::ConstantOp>(
363         loc, typeConverter->convertType(rewriter.getI32Type()),
364         rewriter.getI32IntegerAttr(2));
365     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
366         loc, indexPtrTy, scalarMemRefDescPtr,
367         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
368 
369     // The size value that we have to extract can be obtained by using a GEPOp
370     // with `dimOp.index() + 1` as the index argument.
371     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
372         loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
373     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
374                                                  ValueRange({idxPlusOne}));
375     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
376   }
377 
378   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
379     if (Optional<int64_t> idx = dimOp.getConstantIndex())
380       return idx;
381 
382     if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
383       return constantOp.getValue()
384           .cast<IntegerAttr>()
385           .getValue()
386           .getSExtValue();
387 
388     return llvm::None;
389   }
390 
391   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
392                                   OpAdaptor adaptor,
393                                   ConversionPatternRewriter &rewriter) const {
394     Location loc = dimOp.getLoc();
395 
396     // Take advantage of a constant index, if available.
397     MemRefType memRefType = operandType.cast<MemRefType>();
398     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
399       int64_t i = index.getValue();
400       if (memRefType.isDynamicDim(i)) {
401         // extract dynamic size from the memref descriptor.
402         MemRefDescriptor descriptor(adaptor.source());
403         return descriptor.size(rewriter, loc, i);
404       }
405       // Use constant for static size.
406       int64_t dimSize = memRefType.getDimSize(i);
407       return createIndexConstant(rewriter, loc, dimSize);
408     }
409     Value index = adaptor.index();
410     int64_t rank = memRefType.getRank();
411     MemRefDescriptor memrefDescriptor(adaptor.source());
412     return memrefDescriptor.size(rewriter, loc, index, rank);
413   }
414 };
415 
416 /// Common base for load and store operations on MemRefs. Restricts the match
417 /// to supported MemRef types. Provides functionality to emit code accessing a
418 /// specific element of the underlying data buffer.
419 template <typename Derived>
420 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
421   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
422   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
423   using Base = LoadStoreOpLowering<Derived>;
424 
425   LogicalResult match(Derived op) const override {
426     MemRefType type = op.getMemRefType();
427     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
428   }
429 };
430 
431 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
432 /// retried until it succeeds in atomically storing a new value into memory.
433 ///
434 ///      +---------------------------------+
435 ///      |   <code before the AtomicRMWOp> |
436 ///      |   <compute initial %loaded>     |
437 ///      |   cf.br loop(%loaded)           |
438 ///      +---------------------------------+
439 ///             |
440 ///  -------|   |
441 ///  |      v   v
442 ///  |   +-----------------------------------+
443 ///  |   | loop(%loaded):                    |
444 ///  |   |   <body contents>                 |
445 ///  |   |   %pair = cmpxchg                 |
446 ///  |   |   %ok = %pair[0]                  |
447 ///  |   |   %new = %pair[1]                 |
448 ///  |   |   cf.cond_br %ok, end, loop(%new) |
449 ///  |   +-----------------------------------+
450 ///  |          |        |
451 ///  |-----------        |
452 ///                      v
453 ///      +--------------------------------+
454 ///      | end:                           |
455 ///      |   <code after the AtomicRMWOp> |
456 ///      +--------------------------------+
457 ///
458 struct GenericAtomicRMWOpLowering
459     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
460   using Base::Base;
461 
462   LogicalResult
463   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
464                   ConversionPatternRewriter &rewriter) const override {
465     auto loc = atomicOp.getLoc();
466     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
467 
468     // Split the block into initial, loop, and ending parts.
469     auto *initBlock = rewriter.getInsertionBlock();
470     auto *loopBlock = rewriter.createBlock(
471         initBlock->getParent(), std::next(Region::iterator(initBlock)),
472         valueType, loc);
473     auto *endBlock = rewriter.createBlock(
474         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
475 
476     // Operations range to be moved to `endBlock`.
477     auto opsToMoveStart = atomicOp->getIterator();
478     auto opsToMoveEnd = initBlock->back().getIterator();
479 
480     // Compute the loaded value and branch to the loop block.
481     rewriter.setInsertionPointToEnd(initBlock);
482     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
483     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
484                                         adaptor.indices(), rewriter);
485     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
486     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
487 
488     // Prepare the body of the loop block.
489     rewriter.setInsertionPointToStart(loopBlock);
490 
491     // Clone the GenericAtomicRMWOp region and extract the result.
492     auto loopArgument = loopBlock->getArgument(0);
493     BlockAndValueMapping mapping;
494     mapping.map(atomicOp.getCurrentValue(), loopArgument);
495     Block &entryBlock = atomicOp.body().front();
496     for (auto &nestedOp : entryBlock.without_terminator()) {
497       Operation *clone = rewriter.clone(nestedOp, mapping);
498       mapping.map(nestedOp.getResults(), clone->getResults());
499     }
500     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
501 
502     // Prepare the epilog of the loop block.
503     // Append the cmpxchg op to the end of the loop block.
504     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
505     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
506     auto boolType = IntegerType::get(rewriter.getContext(), 1);
507     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
508                                                      {valueType, boolType});
509     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
510         loc, pairType, dataPtr, loopArgument, result, successOrdering,
511         failureOrdering);
512     // Extract the %new_loaded and %ok values from the pair.
513     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
514         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
515     Value ok = rewriter.create<LLVM::ExtractValueOp>(
516         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
517 
518     // Conditionally branch to the end or back to the loop depending on %ok.
519     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
520                                     loopBlock, newLoaded);
521 
522     rewriter.setInsertionPointToEnd(endBlock);
523     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
524                  std::next(opsToMoveEnd), rewriter);
525 
526     // The 'result' of the atomic_rmw op is the newly loaded value.
527     rewriter.replaceOp(atomicOp, {newLoaded});
528 
529     return success();
530   }
531 
532 private:
533   // Clones a segment of ops [start, end) and erases the original.
534   void moveOpsRange(ValueRange oldResult, ValueRange newResult,
535                     Block::iterator start, Block::iterator end,
536                     ConversionPatternRewriter &rewriter) const {
537     BlockAndValueMapping mapping;
538     mapping.map(oldResult, newResult);
539     SmallVector<Operation *, 2> opsToErase;
540     for (auto it = start; it != end; ++it) {
541       rewriter.clone(*it, mapping);
542       opsToErase.push_back(&*it);
543     }
544     for (auto *it : opsToErase)
545       rewriter.eraseOp(it);
546   }
547 };
548 
549 /// Returns the LLVM type of the global variable given the memref type `type`.
550 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
551                                           LLVMTypeConverter &typeConverter) {
552     // LLVM type for a global memref will be a multi-dimensional array. For
553     // declarations or uninitialized global memrefs, we can potentially flatten
554     // this to a 1D array. However, for memref.global's with an initial value,
555     // we do not intend to flatten the ElementsAttribute when going from memref
556     // to LLVM, so the LLVM type needs to be a multi-dimensional array.
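    // For example, memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>>.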
557   Type elementType = typeConverter.convertType(type.getElementType());
558   Type arrayTy = elementType;
559   // Shape has the outermost dim at index 0, so we need to walk it backwards.
560   for (int64_t dim : llvm::reverse(type.getShape()))
561     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
562   return arrayTy;
563 }
564 
565 /// A memref.global is lowered to an LLVM global variable.
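///
/// For example (an illustrative sketch; attribute syntax may differ):
///
///   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
///
/// becomes roughly
///
///   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>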
566 struct GlobalMemrefOpLowering
567     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
568   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
569 
570   LogicalResult
571   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
572                   ConversionPatternRewriter &rewriter) const override {
573     MemRefType type = global.type();
574     if (!isConvertibleAndHasIdentityMaps(type))
575       return failure();
576 
577     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
578 
579     LLVM::Linkage linkage =
580         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
581 
582     Attribute initialValue = nullptr;
583     if (!global.isExternal() && !global.isUninitialized()) {
584       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
585       initialValue = elementsAttr;
586 
587       // For scalar memrefs, the global variable created is of the element type,
588       // so unpack the elements attribute to extract the value.
589       if (type.getRank() == 0)
590         initialValue = elementsAttr.getSplatValue<Attribute>();
591     }
592 
593     uint64_t alignment = global.alignment().getValueOr(0);
594 
595     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
596         global, arrayTy, global.constant(), linkage, global.sym_name(),
597         initialValue, alignment, type.getMemorySpaceAsInt());
598     if (!global.isExternal() && global.isUninitialized()) {
599       Block *blk = new Block();
600       newGlobal.getInitializerRegion().push_back(blk);
601       rewriter.setInsertionPointToStart(blk);
602       Value undef[] = {
603           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
604       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
605     }
606     return success();
607   }
608 };
609 
610 /// A memref.get_global is lowered into a MemRef descriptor with a pointer to
611 /// the first element of the global stashed into it. This derives from
612 /// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
613 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
614   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
615       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
616                                 converter) {}
617 
618   /// Buffer "allocation" for a memref.get_global op amounts to taking the
619   /// address of the referenced global variable.
620   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
621                                           Location loc, Value sizeBytes,
622                                           Operation *op) const override {
623     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
624     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
625     unsigned memSpace = type.getMemorySpaceAsInt();
626 
627     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
628     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
629         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
630 
631     // Get the address of the first element in the array by creating a GEP with
632     // the address of the GV as the base, and (rank + 1) number of 0 indices.
633     Type elementType = typeConverter->convertType(type.getElementType());
634     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
635 
636     SmallVector<Value> operands;
637     operands.insert(operands.end(), type.getRank() + 1,
638                     createIndexConstant(rewriter, loc, 0));
639     auto gep =
640         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
641 
642     // We do not expect the memref obtained using `memref.get_global` to ever
643     // be deallocated. Set the allocated pointer to a known bad value to help
644     // debug if that ever happens.
645     auto intPtrType = getIntPtrType(memSpace);
646     Value deadBeefConst =
647         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
648     auto deadBeefPtr =
649         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
650 
651     // The allocated and aligned pointers are the same. We could potentially
652     // stash a nullptr for the allocated pointer since no dealloc is expected.
653     return std::make_tuple(deadBeefPtr, gep);
654   }
655 };
656 
657 // Load operation is lowered to obtaining a pointer to the indexed element
658 // and loading it.
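//
// For example (sketch), `memref.load %m[%i, %j] : memref<?x?xf32>` computes
// the address `alignedPtr + offset + %i * stride0 + %j * stride1` from the
// descriptor and replaces the op with an `llvm.load` of that address.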
659 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
660   using Base::Base;
661 
662   LogicalResult
663   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
664                   ConversionPatternRewriter &rewriter) const override {
665     auto type = loadOp.getMemRefType();
666 
667     Value dataPtr = getStridedElementPtr(
668         loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
669     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
670     return success();
671   }
672 };
673 
674 // Store operation is lowered to obtaining a pointer to the indexed element,
675 // and storing the given value to it.
676 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
677   using Base::Base;
678 
679   LogicalResult
680   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
681                   ConversionPatternRewriter &rewriter) const override {
682     auto type = op.getMemRefType();
683 
684     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
685                                          adaptor.indices(), rewriter);
686     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
687     return success();
688   }
689 };
690 
691 // The prefetch operation is lowered in a way similar to the load operation
692 // except that the llvm.prefetch operation is used for replacement.
693 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
694   using Base::Base;
695 
696   LogicalResult
697   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
698                   ConversionPatternRewriter &rewriter) const override {
699     auto type = prefetchOp.getMemRefType();
700     auto loc = prefetchOp.getLoc();
701 
702     Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
703                                          adaptor.indices(), rewriter);
704 
705     // Replace with llvm.prefetch.
706     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
707     auto isWrite = rewriter.create<LLVM::ConstantOp>(
708         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
709     auto localityHint = rewriter.create<LLVM::ConstantOp>(
710         loc, llvmI32Type,
711         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
712     auto isData = rewriter.create<LLVM::ConstantOp>(
713         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
714 
715     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
716                                                 localityHint, isData);
717     return success();
718   }
719 };
720 
721 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
722   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
723 
724   LogicalResult
725   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
726                   ConversionPatternRewriter &rewriter) const override {
727     Location loc = op.getLoc();
728     Type operandType = op.memref().getType();
729     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
730       UnrankedMemRefDescriptor desc(adaptor.memref());
731       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
732       return success();
733     }
734     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
735       rewriter.replaceOp(
736           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
737       return success();
738     }
739     return failure();
740   }
741 };
742 
743 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
744   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
745 
746   LogicalResult match(memref::CastOp memRefCastOp) const override {
747     Type srcType = memRefCastOp.getOperand().getType();
748     Type dstType = memRefCastOp.getType();
749 
750     // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
751     // used for type erasure. For now it must preserve the underlying element
752     // type and requires the source and result types to have the same rank.
753     // Therefore, perform a sanity check that the underlying structs are the
754     // same. Once the op semantics are relaxed we can revisit.
755     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
756       return success(typeConverter->convertType(srcType) ==
757                      typeConverter->convertType(dstType));
758 
759     // At least one of the operands has an unranked type.
760     assert(srcType.isa<UnrankedMemRefType>() ||
761            dstType.isa<UnrankedMemRefType>());
762 
763     // Unranked to unranked cast is disallowed.
764     return !(srcType.isa<UnrankedMemRefType>() &&
765              dstType.isa<UnrankedMemRefType>())
766                ? success()
767                : failure();
768   }
769 
770   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
771                ConversionPatternRewriter &rewriter) const override {
772     auto srcType = memRefCastOp.getOperand().getType();
773     auto dstType = memRefCastOp.getType();
774     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
775     auto loc = memRefCastOp.getLoc();
776 
777     // For ranked/ranked case, just keep the original descriptor.
778     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
779       return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
780 
781     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
782       // Casting ranked to unranked memref type
783       // Set the rank in the destination from the memref type
784       // Allocate space on the stack and copy the src memref descriptor
785       // Set the ptr in the destination to the stack space
786       auto srcMemRefType = srcType.cast<MemRefType>();
787       int64_t rank = srcMemRefType.getRank();
788       // ptr = AllocaOp sizeof(MemRefDescriptor)
789       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
790           loc, adaptor.source(), rewriter);
791       // voidptr = BitCastOp srcType* to void*
792       auto voidPtr =
793           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
794               .getResult();
795       // rank = ConstantOp srcRank
796       auto rankVal = rewriter.create<LLVM::ConstantOp>(
797           loc, getIndexType(), rewriter.getIndexAttr(rank));
798       // undef = UndefOp
799       UnrankedMemRefDescriptor memRefDesc =
800           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
801       // d1 = InsertValueOp undef, rank, 0
802       memRefDesc.setRank(rewriter, loc, rankVal);
803       // d2 = InsertValueOp d1, voidptr, 1
804       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
805       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
806 
807     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
808       // Casting from unranked type to ranked.
809       // The operation is assumed to be performing a correct cast. If the
810       // destination type mismatches the unranked type, it is undefined behavior.
811       UnrankedMemRefDescriptor memRefDesc(adaptor.source());
812       // ptr = ExtractValueOp src, 1
813       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
814       // castPtr = BitCastOp i8* to structTy*
815       auto castPtr =
816           rewriter
817               .create<LLVM::BitcastOp>(
818                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
819               .getResult();
820       // struct = LoadOp castPtr
821       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
822       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
823     } else {
824       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
825     }
826   }
827 };
828 
829 /// Pattern to lower a `memref.copy` to llvm.
830 ///
831 /// For memrefs with identity layouts, the copy is lowered to the llvm
832 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
833 /// to the generic `MemrefCopyFn`.
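///
/// For a statically shaped, contiguous copy such as
/// `memref.copy %a, %b : memref<4x4xf32> to memref<4x4xf32>`, for instance,
/// this emits a single `llvm.intr.memcpy` of 4 * 4 * 4 = 64 bytes between the
/// two aligned pointers (each adjusted by its offset).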
834 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
835   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
836 
837   LogicalResult
838   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
839                           ConversionPatternRewriter &rewriter) const {
840     auto loc = op.getLoc();
841     auto srcType = op.source().getType().dyn_cast<MemRefType>();
842 
843     MemRefDescriptor srcDesc(adaptor.source());
844 
845     // Compute number of elements.
846     Value numElements = rewriter.create<LLVM::ConstantOp>(
847         loc, getIndexType(), rewriter.getIndexAttr(1));
848     for (int pos = 0; pos < srcType.getRank(); ++pos) {
849       auto size = srcDesc.size(rewriter, loc, pos);
850       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
851     }
852 
853     // Get element size.
854     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
855     // Compute total.
856     Value totalSize =
857         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
858 
859     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
860     Value srcOffset = srcDesc.offset(rewriter, loc);
861     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
862                                                 srcBasePtr, srcOffset);
863     MemRefDescriptor targetDesc(adaptor.target());
864     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
865     Value targetOffset = targetDesc.offset(rewriter, loc);
866     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
867                                                    targetBasePtr, targetOffset);
868     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
869         loc, typeConverter->convertType(rewriter.getI1Type()),
870         rewriter.getBoolAttr(false));
871     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
872                                     isVolatile);
873     rewriter.eraseOp(op);
874 
875     return success();
876   }
877 
878   LogicalResult
879   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
880                              ConversionPatternRewriter &rewriter) const {
881     auto loc = op.getLoc();
882     auto srcType = op.source().getType().cast<BaseMemRefType>();
883     auto targetType = op.target().getType().cast<BaseMemRefType>();
884 
885     // First make sure we have an unranked memref descriptor representation.
886     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
887       auto rank = rewriter.create<LLVM::ConstantOp>(
888           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
889       auto *typeConverter = getTypeConverter();
890       auto ptr =
891           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
892       auto voidPtr =
893           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
894               .getResult();
895       auto unrankedType =
896           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
897       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
898                                             unrankedType,
899                                             ValueRange{rank, voidPtr});
900     };
901 
902     Value unrankedSource = srcType.hasRank()
903                                ? makeUnranked(adaptor.source(), srcType)
904                                : adaptor.source();
905     Value unrankedTarget = targetType.hasRank()
906                                ? makeUnranked(adaptor.target(), targetType)
907                                : adaptor.target();
908 
909     // Now promote the unranked descriptors to the stack.
910     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
911                                                  rewriter.getIndexAttr(1));
912     auto promote = [&](Value desc) {
913       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
914       auto allocated =
915           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
916       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
917       return allocated;
918     };
919 
920     auto sourcePtr = promote(unrankedSource);
921     auto targetPtr = promote(unrankedTarget);
922 
923     unsigned typeSize =
924         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
925     auto elemSize = rewriter.create<LLVM::ConstantOp>(
926         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
927     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
928         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
929     rewriter.create<LLVM::CallOp>(loc, copyFn,
930                                   ValueRange{elemSize, sourcePtr, targetPtr});
931     rewriter.eraseOp(op);
932 
933     return success();
934   }
935 
936   LogicalResult
937   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
938                   ConversionPatternRewriter &rewriter) const override {
939     auto srcType = op.source().getType().cast<BaseMemRefType>();
940     auto targetType = op.target().getType().cast<BaseMemRefType>();
941 
942     auto isContiguousMemrefType = [](BaseMemRefType type) {
943       auto memrefType = type.dyn_cast<mlir::MemRefType>();
944       // We can use memcpy for memrefs if they have an identity layout or are
945       // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
946       // special case handled by memrefCopy.
947       return memrefType &&
948              (memrefType.getLayout().isIdentity() ||
949               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
950                isStaticShapeAndContiguousRowMajor(memrefType)));
951     };
952 
953     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
954       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
955 
956     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
957   }
958 };
959 
960 /// Extracts the allocated and aligned pointers and the offset from a ranked or
961 /// unranked memref value. In the unranked case, the fields are extracted from
962 /// the underlying ranked descriptor.
963 static void extractPointersAndOffset(Location loc,
964                                      ConversionPatternRewriter &rewriter,
965                                      LLVMTypeConverter &typeConverter,
966                                      Value originalOperand,
967                                      Value convertedOperand,
968                                      Value *allocatedPtr, Value *alignedPtr,
969                                      Value *offset = nullptr) {
970   Type operandType = originalOperand.getType();
971   if (operandType.isa<MemRefType>()) {
972     MemRefDescriptor desc(convertedOperand);
973     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
974     *alignedPtr = desc.alignedPtr(rewriter, loc);
975     if (offset != nullptr)
976       *offset = desc.offset(rewriter, loc);
977     return;
978   }
979 
980   unsigned memorySpace =
981       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
982   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
983   Type llvmElementType = typeConverter.convertType(elementType);
984   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
985       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
986 
987   // Extract pointer to the underlying ranked memref descriptor and cast it to
988   // ElemType**.
989   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
990   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
991 
992   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
993       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
994   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
995       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
996   if (offset != nullptr) {
997     *offset = UnrankedMemRefDescriptor::offset(
998         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
999   }
1000 }
1001 
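/// Lowers memref.reinterpret_cast by building a fresh MemRef descriptor that
/// reuses the source's allocated and aligned pointers and takes its offset,
/// sizes and strides from the op's static/dynamic operands. As a hedged
/// example (result layout annotation elided), in
///
///   memref.reinterpret_cast %src to offset: [0], sizes: [%d, 10],
///       strides: [10, 1] : memref<?xf32> to memref<?x10xf32, ...>
///
/// the constants 0, 10, 10 and 1 are materialized directly into the descriptor
/// while %d is copied from the converted size operand.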
1002 struct MemRefReinterpretCastOpLowering
1003     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1004   using ConvertOpToLLVMPattern<
1005       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1006 
1007   LogicalResult
1008   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1009                   ConversionPatternRewriter &rewriter) const override {
1010     Type srcType = castOp.source().getType();
1011 
1012     Value descriptor;
1013     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1014                                                adaptor, &descriptor)))
1015       return failure();
1016     rewriter.replaceOp(castOp, {descriptor});
1017     return success();
1018   }
1019 
1020 private:
1021   LogicalResult convertSourceMemRefToDescriptor(
1022       ConversionPatternRewriter &rewriter, Type srcType,
1023       memref::ReinterpretCastOp castOp,
1024       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1025     MemRefType targetMemRefType =
1026         castOp.getResult().getType().cast<MemRefType>();
1027     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1028                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1029     if (!llvmTargetDescriptorTy)
1030       return failure();
1031 
1032     // Create descriptor.
1033     Location loc = castOp.getLoc();
1034     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1035 
1036     // Set allocated and aligned pointers.
1037     Value allocatedPtr, alignedPtr;
1038     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1039                              castOp.source(), adaptor.source(), &allocatedPtr,
1040                              &alignedPtr);
1041     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1042     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1043 
1044     // Set offset.
1045     if (castOp.isDynamicOffset(0))
1046       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
1047     else
1048       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1049 
1050     // Set sizes and strides.
1051     unsigned dynSizeId = 0;
1052     unsigned dynStrideId = 0;
1053     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1054       if (castOp.isDynamicSize(i))
1055         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
1056       else
1057         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1058 
1059       if (castOp.isDynamicStride(i))
1060         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
1061       else
1062         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1063     }
1064     *descriptor = desc;
1065     return success();
1066   }
1067 };
1068 
1069 struct MemRefReshapeOpLowering
1070     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1071   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1072 
1073   LogicalResult
1074   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1075                   ConversionPatternRewriter &rewriter) const override {
1076     Type srcType = reshapeOp.source().getType();
1077 
1078     Value descriptor;
1079     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1080                                                adaptor, &descriptor)))
1081       return failure();
1082     rewriter.replaceOp(reshapeOp, {descriptor});
1083     return success();
1084   }
1085 
1086 private:
1087   LogicalResult
1088   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1089                                   Type srcType, memref::ReshapeOp reshapeOp,
1090                                   memref::ReshapeOp::Adaptor adaptor,
1091                                   Value *descriptor) const {
1092     // Conversion for statically-known shape args is performed via
1093     // `memref.reinterpret_cast`.
1094     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
1095     if (shapeMemRefType.hasStaticShape())
1096       return failure();
1097 
1098     // The shape is a rank-1 memref with unknown length.
1099     Location loc = reshapeOp.getLoc();
1100     MemRefDescriptor shapeDesc(adaptor.shape());
1101     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1102 
1103     // Extract address space and element type.
1104     auto targetType =
1105         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1106     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1107     Type elementType = targetType.getElementType();
1108 
1109     // Create the unranked memref descriptor that holds the ranked one. The
1110     // inner descriptor is allocated on stack.
1111     auto targetDesc = UnrankedMemRefDescriptor::undef(
1112         rewriter, loc, typeConverter->convertType(targetType));
1113     targetDesc.setRank(rewriter, loc, resultRank);
1114     SmallVector<Value, 4> sizes;
1115     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1116                                            targetDesc, sizes);
1117     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1118         loc, getVoidPtrType(), sizes.front(), llvm::None);
1119     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1120 
1121     // Extract pointers and offset from the source memref.
1122     Value allocatedPtr, alignedPtr, offset;
1123     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1124                              reshapeOp.source(), adaptor.source(),
1125                              &allocatedPtr, &alignedPtr, &offset);
1126 
1127     // Set pointers and offset.
1128     Type llvmElementType = typeConverter->convertType(elementType);
1129     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1130         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1131     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1132                                               elementPtrPtrType, allocatedPtr);
1133     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1134                                             underlyingDescPtr,
1135                                             elementPtrPtrType, alignedPtr);
1136     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1137                                         underlyingDescPtr, elementPtrPtrType,
1138                                         offset);
1139 
1140     // Use the offset pointer as base for further addressing. Copy over the new
1141     // shape and compute strides. For this, we create a loop from rank-1 to 0.
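    // For example, a result shape of [2, 3, 4] yields strides [12, 4, 1]: the
    // loop starts at dimension rank - 1 with a running stride of 1 and, after
    // writing each size, multiplies the running stride by that size.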
1142     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1143         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1144         elementPtrPtrType);
1145     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1146         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1147     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1148     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1149     Value resultRankMinusOne =
1150         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1151 
1152     Block *initBlock = rewriter.getInsertionBlock();
1153     Type indexType = getTypeConverter()->getIndexType();
1154     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1155 
1156     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1157                                             {indexType, indexType}, {loc, loc});
1158 
1159     // Move the remaining initBlock ops to condBlock.
1160     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1161     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1162 
1163     rewriter.setInsertionPointToEnd(initBlock);
1164     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1165                                 condBlock);
1166     rewriter.setInsertionPointToStart(condBlock);
1167     Value indexArg = condBlock->getArgument(0);
1168     Value strideArg = condBlock->getArgument(1);
1169 
1170     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1171     Value pred = rewriter.create<LLVM::ICmpOp>(
1172         loc, IntegerType::get(rewriter.getContext(), 1),
1173         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1174 
1175     Block *bodyBlock =
1176         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1177     rewriter.setInsertionPointToStart(bodyBlock);
1178 
1179     // Copy size from shape to descriptor.
1180     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1181     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1182         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1183     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1184     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1185                                       targetSizesBase, indexArg, size);
1186 
1187     // Write stride value and compute next one.
1188     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1189                                         targetStridesBase, indexArg, strideArg);
1190     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1191 
1192     // Decrement loop counter and branch back.
1193     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1194     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1195                                 condBlock);
1196 
1197     Block *remainder =
1198         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1199 
1200     // Hook up the cond exit to the remainder.
1201     rewriter.setInsertionPointToEnd(condBlock);
1202     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1203                                     llvm::None);
1204 
1205     // Reset position to beginning of new remainder block.
1206     rewriter.setInsertionPointToStart(remainder);
1207 
1208     *descriptor = targetDesc;
1209     return success();
1210   }
1211 };
1212 
1213 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1214 /// `Value`s.
1215 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1216                                       Type &llvmIndexType,
1217                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1218   return llvm::to_vector<4>(
1219       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1220         if (auto attr = value.dyn_cast<Attribute>())
1221           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1222         return value.get<Value>();
1223       }));
1224 }
1225 
/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
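///
/// For example, with reassociation [[0, 1], [2]], expanded dims 0 and 1 both
/// map to collapsed dim 0 and expanded dim 2 maps to collapsed dim 1.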
1229 static DenseMap<int64_t, int64_t>
1230 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1231   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1232   for (auto &en : enumerate(reassociation)) {
1233     for (auto dim : en.value())
1234       expandedDimToCollapsedDim[dim] = en.index();
1235   }
1236   return expandedDimToCollapsedDim;
1237 }
1238 
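/// Compute the size of the `outDimIndex`-th dimension of the expanded type.
/// Static sizes are returned as attributes. A dynamic size is computed by
/// dividing the (possibly dynamic) size of the corresponding input dimension
/// by the product of the other, statically known, output sizes of the same
/// reassociation group.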
1239 static OpFoldResult
1240 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1241                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1242                          MemRefDescriptor &inDesc,
1243                          ArrayRef<int64_t> inStaticShape,
1244                          ArrayRef<ReassociationIndices> reassocation,
1245                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1246   int64_t outDimSize = outStaticShape[outDimIndex];
1247   if (!ShapedType::isDynamic(outDimSize))
1248     return b.getIndexAttr(outDimSize);
1249 
  // Compute the product of all the output dim sizes except the current one.
1252   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1253   int64_t otherDimSizesMul = 1;
1254   for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
1256       continue;
1257     int64_t otherDimSize = outStaticShape[otherDimIndex];
1258     assert(!ShapedType::isDynamic(otherDimSize) &&
1259            "single dimension cannot be expanded into multiple dynamic "
1260            "dimensions");
1261     otherDimSizesMul *= otherDimSize;
1262   }
1263 
  // outDimSize = inDimSize / otherDimSizesMul
1265   int64_t inDimSize = inStaticShape[inDimIndex];
1266   Value inDimSizeDynamic =
1267       ShapedType::isDynamic(inDimSize)
1268           ? inDesc.size(b, loc, inDimIndex)
1269           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1270                                        b.getIndexAttr(inDimSize));
1271   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1272       loc, inDimSizeDynamic,
1273       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1274                                  b.getIndexAttr(otherDimSizesMul)));
1275   return outDimSizeDynamic;
1276 }
1277 
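/// Compute the size of the `outDimIndex`-th dimension of the collapsed type
/// as the product of the sizes of the input dimensions in the corresponding
/// reassociation group. Static sizes are returned as attributes.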
1278 static OpFoldResult getCollapsedOutputDimSize(
1279     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1280     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1281     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1282   if (!ShapedType::isDynamic(outDimSize))
1283     return b.getIndexAttr(outDimSize);
1284 
1285   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1286   Value outDimSizeDynamic = c1;
1287   for (auto inDimIndex : reassocation[outDimIndex]) {
1288     int64_t inDimSize = inStaticShape[inDimIndex];
1289     Value inDimSizeDynamic =
1290         ShapedType::isDynamic(inDimSize)
1291             ? inDesc.size(b, loc, inDimIndex)
1292             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1293                                          b.getIndexAttr(inDimSize));
1294     outDimSizeDynamic =
1295         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1296   }
1297   return outDimSizeDynamic;
1298 }
1299 
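/// Compute the shape of the result of a collapsing reshape, one entry per
/// result dimension.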
1300 static SmallVector<OpFoldResult, 4>
1301 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1302                         ArrayRef<ReassociationIndices> reassocation,
1303                         ArrayRef<int64_t> inStaticShape,
1304                         MemRefDescriptor &inDesc,
1305                         ArrayRef<int64_t> outStaticShape) {
1306   return llvm::to_vector<4>(llvm::map_range(
1307       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1308         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1309                                          outStaticShape[outDimIndex],
1310                                          inStaticShape, inDesc, reassocation);
1311       }));
1312 }
1313 
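/// Compute the shape of the result of an expanding reshape, one entry per
/// result dimension.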
1314 static SmallVector<OpFoldResult, 4>
1315 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1316                        ArrayRef<ReassociationIndices> reassocation,
1317                        ArrayRef<int64_t> inStaticShape,
1318                        MemRefDescriptor &inDesc,
1319                        ArrayRef<int64_t> outStaticShape) {
1320   DenseMap<int64_t, int64_t> outDimToInDimMap =
1321       getExpandedDimToCollapsedDimMap(reassocation);
1322   return llvm::to_vector<4>(llvm::map_range(
1323       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1324         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1325                                         outStaticShape, inDesc, inStaticShape,
1326                                         reassocation, outDimToInDimMap);
1327       }));
1328 }
1329 
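/// Compute the result shape as SSA values, dispatching to the collapsing or
/// the expanding helper depending on whether the result rank is smaller or
/// larger than the source rank.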
1330 static SmallVector<Value>
1331 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1332                       ArrayRef<ReassociationIndices> reassocation,
1333                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1334                       ArrayRef<int64_t> outStaticShape) {
1335   return outStaticShape.size() < inStaticShape.size()
1336              ? getAsValues(b, loc, llvmIndexType,
1337                            getCollapsedOutputShape(b, loc, llvmIndexType,
1338                                                    reassocation, inStaticShape,
1339                                                    inDesc, outStaticShape))
1340              : getAsValues(b, loc, llvmIndexType,
1341                            getExpandedOutputShape(b, loc, llvmIndexType,
1342                                                   reassocation, inStaticShape,
1343                                                   inDesc, outStaticShape));
1344 }
1345 
/// Conversion pattern for the reassociating reshape ops, i.e.
/// `memref.expand_shape` and `memref.collapse_shape`. The result descriptor
/// reuses the allocated and aligned pointers as well as the offset of the
/// source descriptor; sizes and strides are recomputed from the reassociation
/// indices. Non-identity layouts are only supported when all shapes are
/// static.
1349 template <typename ReshapeOp>
1350 class ReassociatingReshapeOpConversion
1351     : public ConvertOpToLLVMPattern<ReshapeOp> {
1352 public:
1353   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1354   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1355 
1356   LogicalResult
1357   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1358                   ConversionPatternRewriter &rewriter) const override {
1359     MemRefType dstType = reshapeOp.getResultType();
1360     MemRefType srcType = reshapeOp.getSrcType();
1361 
1362     // The condition on the layouts can be ignored when all shapes are static.
1363     if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
1364       if (!srcType.getLayout().isIdentity() ||
1365           !dstType.getLayout().isIdentity()) {
1366         return rewriter.notifyMatchFailure(
            reshapeOp, "only identity layout map is supported");
1368       }
1369     }
1370 
1371     int64_t offset;
1372     SmallVector<int64_t, 4> strides;
1373     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1374       return rewriter.notifyMatchFailure(
1375           reshapeOp, "failed to get stride and offset exprs");
1376     }
1377 
1378     MemRefDescriptor srcDesc(adaptor.src());
1379     Location loc = reshapeOp->getLoc();
1380     auto dstDesc = MemRefDescriptor::undef(
1381         rewriter, loc, this->typeConverter->convertType(dstType));
1382     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1383     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1384     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1385 
1386     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1387     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1388     Type llvmIndexType =
1389         this->typeConverter->convertType(rewriter.getIndexType());
1390     SmallVector<Value> dstShape = getDynamicOutputShape(
1391         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1392         srcStaticShape, srcDesc, dstStaticShape);
1393     for (auto &en : llvm::enumerate(dstShape))
1394       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1395 
1396     auto isStaticStride = [](int64_t stride) {
1397       return !ShapedType::isDynamicStrideOrOffset(stride);
1398     };
1399     if (llvm::all_of(strides, isStaticStride)) {
1400       for (auto &en : llvm::enumerate(strides))
1401         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1402     } else {
1403       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1404                                                    rewriter.getIndexAttr(1));
1405       Value stride = c1;
1406       for (auto dimIndex :
1407            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1408         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1409         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1410       }
1411     }
1412     rewriter.replaceOp(reshapeOp, {dstDesc});
1413     return success();
1414   }
1415 };
1416 
1417 /// Conversion pattern that transforms a subview op into:
1418 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1419 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1420 ///      and stride.
1421 /// The subview op is replaced by the descriptor.
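///
/// For example (illustrative sketch), for a rank-1 subview with dynamic
/// offset %o, size %s and stride %t, the resulting descriptor holds
///   offset = source offset + %o * source stride,
///   size   = %s,
///   stride = %t * source stride,
/// while the allocated and aligned pointers are copied (bitcast to the result
/// element type) from the source descriptor.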
1422 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1423   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1424 
1425   LogicalResult
1426   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1427                   ConversionPatternRewriter &rewriter) const override {
1428     auto loc = subViewOp.getLoc();
1429 
1430     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
1431     auto sourceElementTy =
1432         typeConverter->convertType(sourceMemRefType.getElementType());
1433 
1434     auto viewMemRefType = subViewOp.getType();
1435     auto inferredType = memref::SubViewOp::inferResultType(
1436                             subViewOp.getSourceType(),
1437                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
1438                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
1439                             extractFromI64ArrayAttr(subViewOp.static_strides()))
1440                             .cast<MemRefType>();
1441     auto targetElementTy =
1442         typeConverter->convertType(viewMemRefType.getElementType());
1443     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1444     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1445         !LLVM::isCompatibleType(sourceElementTy) ||
1446         !LLVM::isCompatibleType(targetElementTy) ||
1447         !LLVM::isCompatibleType(targetDescTy))
1448       return failure();
1449 
1450     // Extract the offset and strides from the type.
1451     int64_t offset;
1452     SmallVector<int64_t, 4> strides;
1453     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1454     if (failed(successStrides))
1455       return failure();
1456 
1457     // Create the descriptor.
1458     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1459       return failure();
1460     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1461     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1462 
1463     // Copy the buffer pointer from the old descriptor to the new one.
1464     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1465     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1466         loc,
1467         LLVM::LLVMPointerType::get(targetElementTy,
1468                                    viewMemRefType.getMemorySpaceAsInt()),
1469         extracted);
1470     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1471 
1472     // Copy the aligned pointer from the old descriptor to the new one.
1473     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1474     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1475         loc,
1476         LLVM::LLVMPointerType::get(targetElementTy,
1477                                    viewMemRefType.getMemorySpaceAsInt()),
1478         extracted);
1479     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1480 
1481     size_t inferredShapeRank = inferredType.getRank();
1482     size_t resultShapeRank = viewMemRefType.getRank();
1483 
1484     // Extract strides needed to compute offset.
1485     SmallVector<Value, 4> strideValues;
1486     strideValues.reserve(inferredShapeRank);
1487     for (unsigned i = 0; i < inferredShapeRank; ++i)
1488       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1489 
1490     // Offset.
1491     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1492     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1493       targetMemRef.setConstantOffset(rewriter, loc, offset);
1494     } else {
1495       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1496       // `inferredShapeRank` may be larger than the number of offset operands
1497       // because of trailing semantics. In this case, the offset is guaranteed
1498       // to be interpreted as 0 and we can just skip the extra dimensions.
1499       for (unsigned i = 0, e = std::min(inferredShapeRank,
1500                                         subViewOp.getMixedOffsets().size());
1501            i < e; ++i) {
1502         Value offset =
1503             // TODO: need OpFoldResult ODS adaptor to clean this up.
1504             subViewOp.isDynamicOffset(i)
1505                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1506                 : rewriter.create<LLVM::ConstantOp>(
1507                       loc, llvmIndexType,
1508                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1509         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1510         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1511       }
1512       targetMemRef.setOffset(rewriter, loc, baseOffset);
1513     }
1514 
1515     // Update sizes and strides.
1516     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1517     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1518     assert(mixedSizes.size() == mixedStrides.size() &&
1519            "expected sizes and strides of equal length");
1520     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
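    // Fill in the sizes and strides of the result descriptor: `i` walks the
    // inferred (rank-preserved) dimensions from the back, `j` walks the result
    // dimensions, and dimensions dropped by rank reduction are skipped.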
1521     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1522          i >= 0 && j >= 0; --i) {
1523       if (unusedDims.test(i))
1524         continue;
1525 
      // `i` may exceed the number of entries in `subViewOp.getMixedSizes()`
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full extent of the source dimension and the stride to be 1.
1529       Value size, stride;
1530       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1531         // If the static size is available, use it directly. This is similar to
1532         // the folding of dim(constant-op) but removes the need for dim to be
1533         // aware of LLVM constants and for this pass to be aware of std
1534         // constants.
1535         int64_t staticSize =
1536             subViewOp.source().getType().cast<MemRefType>().getShape()[i];
1537         if (staticSize != ShapedType::kDynamicSize) {
1538           size = rewriter.create<LLVM::ConstantOp>(
1539               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1540         } else {
1541           Value pos = rewriter.create<LLVM::ConstantOp>(
1542               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1543           Value dim =
1544               rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
1545           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1546               loc, llvmIndexType, dim);
1547           size = cast.getResult(0);
1548         }
1549         stride = rewriter.create<LLVM::ConstantOp>(
1550             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1551       } else {
1552         // TODO: need OpFoldResult ODS adaptor to clean this up.
1553         size =
1554             subViewOp.isDynamicSize(i)
1555                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1556                 : rewriter.create<LLVM::ConstantOp>(
1557                       loc, llvmIndexType,
1558                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1559         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1560           stride = rewriter.create<LLVM::ConstantOp>(
1561               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1562         } else {
1563           stride =
1564               subViewOp.isDynamicStride(i)
1565                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1566                   : rewriter.create<LLVM::ConstantOp>(
1567                         loc, llvmIndexType,
1568                         rewriter.getI64IntegerAttr(
1569                             subViewOp.getStaticStride(i)));
1570           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1571         }
1572       }
1573       targetMemRef.setSize(rewriter, loc, j, size);
1574       targetMemRef.setStride(rewriter, loc, j, stride);
1575       j--;
1576     }
1577 
1578     rewriter.replaceOp(subViewOp, {targetMemRef});
1579     return success();
1580   }
1581 };
1582 
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and
///      permuted sizes and strides.
/// The transpose op is replaced by the updated descriptor.
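///
/// For example (illustrative sketch), with permutation (d0, d1) -> (d1, d0),
/// size/stride 0 of the source descriptor are written to position 1 of the
/// result descriptor and size/stride 1 to position 0, while the pointers and
/// the offset are copied unchanged.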
1590 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1591 public:
1592   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1593 
1594   LogicalResult
1595   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1596                   ConversionPatternRewriter &rewriter) const override {
1597     auto loc = transposeOp.getLoc();
1598     MemRefDescriptor viewMemRef(adaptor.in());
1599 
1600     // No permutation, early exit.
1601     if (transposeOp.permutation().isIdentity())
1602       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1603 
1604     auto targetMemRef = MemRefDescriptor::undef(
1605         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1606 
1607     // Copy the base and aligned pointers from the old descriptor to the new
1608     // one.
1609     targetMemRef.setAllocatedPtr(rewriter, loc,
1610                                  viewMemRef.allocatedPtr(rewriter, loc));
1611     targetMemRef.setAlignedPtr(rewriter, loc,
1612                                viewMemRef.alignedPtr(rewriter, loc));
1613 
    // Copy the offset from the old descriptor to the new one.
1615     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1616 
1617     // Iterate over the dimensions and apply size/stride permutation.
1618     for (const auto &en :
1619          llvm::enumerate(transposeOp.permutation().getResults())) {
1620       int sourcePos = en.index();
1621       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1622       targetMemRef.setSize(rewriter, loc, targetPos,
1623                            viewMemRef.size(rewriter, loc, sourcePos));
1624       targetMemRef.setStride(rewriter, loc, targetPos,
1625                              viewMemRef.stride(rewriter, loc, sourcePos));
1626     }
1627 
1628     rewriter.replaceOp(transposeOp, {targetMemRef});
1629     return success();
1630   }
1631 };
1632 
/// Conversion pattern that transforms a view op into:
1634 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1635 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1636 ///      and stride.
1637 /// The view op is replaced by the descriptor.
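///
/// For example (illustrative sketch), a `memref.view` of an `i8` buffer
/// bitcasts the allocated pointer to the result element type, advances the
/// aligned pointer by the byte shift before bitcasting it, sets the offset to
/// 0, and fills in contiguous (row-major) sizes and strides.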
1638 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1639   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1640 
  // Build and return the value for the idx^th shape dimension: either the
  // constant shape dimension, or the dynamic size operand selected by counting
  // the dynamic dimensions preceding `idx`.
1643   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1644                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1645                 unsigned idx) const {
1646     assert(idx < shape.size());
1647     if (!ShapedType::isDynamic(shape[idx]))
1648       return createIndexConstant(rewriter, loc, shape[idx]);
1649     // Count the number of dynamic dims in range [0, idx]
1650     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1651       return ShapedType::isDynamic(v);
1652     });
1653     return dynamicSizes[nDynamic];
1654   }
1655 
1656   // Build and return the idx^th stride, either by returning the constant stride
1657   // or by computing the dynamic stride from the current `runningStride` and
1658   // `nextSize`. The caller should keep a running stride and update it with the
1659   // result returned by this function.
1660   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1661                   ArrayRef<int64_t> strides, Value nextSize,
1662                   Value runningStride, unsigned idx) const {
1663     assert(idx < strides.size());
1664     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1665       return createIndexConstant(rewriter, loc, strides[idx]);
1666     if (nextSize)
1667       return runningStride
1668                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1669                  : nextSize;
1670     assert(!runningStride);
1671     return createIndexConstant(rewriter, loc, 1);
1672   }
1673 
1674   LogicalResult
1675   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1676                   ConversionPatternRewriter &rewriter) const override {
1677     auto loc = viewOp.getLoc();
1678 
1679     auto viewMemRefType = viewOp.getType();
1680     auto targetElementTy =
1681         typeConverter->convertType(viewMemRefType.getElementType());
1682     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1683     if (!targetDescTy || !targetElementTy ||
1684         !LLVM::isCompatibleType(targetElementTy) ||
1685         !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
1687              failure();
1688 
1689     int64_t offset;
1690     SmallVector<int64_t, 4> strides;
1691     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1692     if (failed(successStrides))
1693       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1694     assert(offset == 0 && "expected offset to be 0");
1695 
1696     // Create the descriptor.
1697     MemRefDescriptor sourceMemRef(adaptor.source());
1698     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1699 
1700     // Field 1: Copy the allocated pointer, used for malloc/free.
1701     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1702     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
1703     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1704         loc,
1705         LLVM::LLVMPointerType::get(targetElementTy,
1706                                    srcMemRefType.getMemorySpaceAsInt()),
1707         allocatedPtr);
1708     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1709 
1710     // Field 2: Copy the actual aligned pointer to payload.
1711     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1712     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
1713                                               alignedPtr, adaptor.byte_shift());
1714     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1715         loc,
1716         LLVM::LLVMPointerType::get(targetElementTy,
1717                                    srcMemRefType.getMemorySpaceAsInt()),
1718         alignedPtr);
1719     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1720 
1721     // Field 3: The offset in the resulting type must be 0. This is because of
1722     // the type change: an offset on srcType* may not be expressible as an
1723     // offset on dstType*.
1724     targetMemRef.setOffset(rewriter, loc,
1725                            createIndexConstant(rewriter, loc, offset));
1726 
1727     // Early exit for 0-D corner case.
1728     if (viewMemRefType.getRank() == 0)
1729       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1730 
1731     // Fields 4 and 5: Update sizes and strides.
1732     if (strides.back() != 1)
1733       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1734              failure();
1735     Value stride = nullptr, nextSize = nullptr;
1736     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1737       // Update size.
1738       Value size =
1739           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
1740       targetMemRef.setSize(rewriter, loc, i, size);
1741       // Update stride.
1742       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1743       targetMemRef.setStride(rewriter, loc, i, stride);
1744       nextSize = size;
1745     }
1746 
1747     rewriter.replaceOp(viewOp, {targetMemRef});
1748     return success();
1749   }
1750 };
1751 
1752 //===----------------------------------------------------------------------===//
1753 // AtomicRMWOpLowering
1754 //===----------------------------------------------------------------------===//
1755 
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
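///
/// For example, `arith::AtomicRMWKind::addf` maps to `LLVM::AtomicBinOp::fadd`,
/// while kinds without a direct `llvm.atomicrmw` counterpart (e.g. `mulf`)
/// return `llvm::None`.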
1758 static Optional<LLVM::AtomicBinOp>
1759 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1760   switch (atomicOp.kind()) {
1761   case arith::AtomicRMWKind::addf:
1762     return LLVM::AtomicBinOp::fadd;
1763   case arith::AtomicRMWKind::addi:
1764     return LLVM::AtomicBinOp::add;
1765   case arith::AtomicRMWKind::assign:
1766     return LLVM::AtomicBinOp::xchg;
1767   case arith::AtomicRMWKind::maxs:
1768     return LLVM::AtomicBinOp::max;
1769   case arith::AtomicRMWKind::maxu:
1770     return LLVM::AtomicBinOp::umax;
1771   case arith::AtomicRMWKind::mins:
1772     return LLVM::AtomicBinOp::min;
1773   case arith::AtomicRMWKind::minu:
1774     return LLVM::AtomicBinOp::umin;
1775   case arith::AtomicRMWKind::ori:
1776     return LLVM::AtomicBinOp::_or;
1777   case arith::AtomicRMWKind::andi:
1778     return LLVM::AtomicBinOp::_and;
1779   default:
1780     return llvm::None;
1781   }
1783 }
1784 
1785 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1786   using Base::Base;
1787 
1788   LogicalResult
1789   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1790                   ConversionPatternRewriter &rewriter) const override {
1791     if (failed(match(atomicOp)))
1792       return failure();
1793     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1794     if (!maybeKind)
1795       return failure();
1796     auto resultType = adaptor.value().getType();
1797     auto memRefType = atomicOp.getMemRefType();
1798     auto dataPtr =
1799         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
1800                              adaptor.indices(), rewriter);
1801     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1802         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
1803         LLVM::AtomicOrdering::acq_rel);
1804     return success();
1805   }
1806 };
1807 
1808 } // namespace
1809 
1810 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1811                                                   RewritePatternSet &patterns) {
1812   // clang-format off
1813   patterns.add<
1814       AllocaOpLowering,
1815       AllocaScopeOpLowering,
1816       AtomicRMWOpLowering,
1817       AssumeAlignmentOpLowering,
1818       DimOpLowering,
1819       GenericAtomicRMWOpLowering,
1820       GlobalMemrefOpLowering,
1821       GetGlobalMemrefOpLowering,
1822       LoadOpLowering,
1823       MemRefCastOpLowering,
1824       MemRefCopyOpLowering,
1825       MemRefReinterpretCastOpLowering,
1826       MemRefReshapeOpLowering,
1827       PrefetchOpLowering,
1828       RankOpLowering,
1829       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1830       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1831       StoreOpLowering,
1832       SubViewOpLowering,
1833       TransposeOpLowering,
1834       ViewOpLowering>(converter);
1835   // clang-format on
1836   auto allocLowering = converter.getOptions().allocLowering;
1837   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1838     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1839   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1840     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1841 }
1842 
1843 namespace {
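/// A pass wrapping populateMemRefToLLVMConversionPatterns: it configures the
/// alloc lowering and the index bitwidth from the pass options and applies a
/// partial conversion of the input operation to the LLVM dialect.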
1844 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1845   MemRefToLLVMPass() = default;
1846 
1847   void runOnOperation() override {
1848     Operation *op = getOperation();
1849     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1850     LowerToLLVMOptions options(&getContext(),
1851                                dataLayoutAnalysis.getAtOrAbove(op));
1852     options.allocLowering =
1853         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1854                          : LowerToLLVMOptions::AllocLowering::Malloc);
1855     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1856       options.overrideIndexBitwidth(indexBitwidth);
1857 
1858     LLVMTypeConverter typeConverter(&getContext(), options,
1859                                     &dataLayoutAnalysis);
1860     RewritePatternSet patterns(&getContext());
1861     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1862     LLVMConversionTarget target(getContext());
1863     target.addLegalOp<FuncOp>();
1864     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1865       signalPassFailure();
1866   }
1867 };
1868 } // namespace
1869 
1870 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1871   return std::make_unique<MemRefToLLVMPass>();
1872 }
1873