1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
bool isStaticStrideOrOffset(int64_t strideOrOffset) {
30   return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
31 }
32 
33 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
35       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
36                                 converter) {}
37 
  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
39     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
40 
41     if (useGenericFn)
42       return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
43 
44     return LLVM::lookupOrCreateMallocFn(module, getIndexType());
45   }
46 
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
48                                           Location loc, Value sizeBytes,
49                                           Operation *op) const override {
50     // Heap allocations.
51     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
52     MemRefType memRefType = allocOp.getType();
53 
54     Value alignment;
55     if (auto alignmentAttr = allocOp.getAlignment()) {
56       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
57     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // When no alignment is specified, we may still want to override
      // `malloc`'s default behavior: `malloc` typically aligns allocations to
      // the size of the largest scalar type on the target hardware. For
      // non-scalar element types, use the natural alignment of the LLVM type
      // given by the LLVM DataLayout instead.
62       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
63     }
64 
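    // Over-allocating by `alignment` extra bytes guarantees that an address
    // satisfying the requested alignment exists within the allocated block:
    // `malloc` may return a pointer that is up to `alignment - 1` bytes below
    // the next aligned boundary, and rounding up still leaves `sizeBytes`
    // usable bytes.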
65     if (alignment) {
66       // Adjust the allocation size to consider alignment.
67       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
68     }
69 
70     // Allocate the underlying buffer and store a pointer to it in the MemRef
71     // descriptor.
72     Type elementPtrType = this->getElementPtrType(memRefType);
73     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
74     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
75                                   getVoidPtrType());
76     Value allocatedPtr =
77         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
78 
79     Value alignedPtr = allocatedPtr;
80     if (alignment) {
81       // Compute the aligned type pointer.
82       Value allocatedInt =
83           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
84       Value alignmentInt =
85           createAligned(rewriter, loc, allocatedInt, alignment);
86       alignedPtr =
87           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
88     }
89 
90     return std::make_tuple(allocatedPtr, alignedPtr);
91   }
92 };
93 
94 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
96       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
97                                 converter) {}
98 
99   /// Returns the memref's element size in bytes using the data layout active at
100   /// `op`.
101   // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
103     const DataLayout *layout = &defaultLayout;
104     if (const DataLayoutAnalysis *analysis =
105             getTypeConverter()->getDataLayoutAnalysis()) {
106       layout = &analysis->getAbove(op);
107     }
108     Type elementType = memRefType.getElementType();
109     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
110       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
111                                                          *layout);
112     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
113       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
114           memRefElementType, *layout);
115     return layout->getTypeSize(elementType);
116   }
117 
118   /// Returns true if the memref size in bytes is known to be a multiple of
119   /// factor assuming the data layout active at `op`.
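  /// For example, memref<4x?xf32> has a 4-byte element type and one static
  /// dimension of 4, so its size in bytes is a multiple of 16 whatever the
  /// dynamic dimension turns out to be.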
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
121                               Operation *op) const {
122     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
123     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
124       if (ShapedType::isDynamic(type.getDimSize(i)))
125         continue;
126       sizeDivisor = sizeDivisor * type.getDimSize(i);
127     }
128     return sizeDivisor % factor == 0;
129   }
130 
  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
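  /// For example, an element size of 24 bytes yields
  /// max(16, PowerOf2Ceil(24)) = 32, while elements of 16 bytes or fewer yield
  /// the minimum alignment of 16.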
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
135     if (Optional<uint64_t> alignment = allocOp.getAlignment())
136       return *alignment;
137 
    // When no alignment is specified, use an alignment consistent with the
    // element type; since the alignment has to be a power of two, bump the
    // element size to the next power of two if it is not one already.
141     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
142     return std::max(kMinAlignedAllocAlignment,
143                     llvm::PowerOf2Ceil(eltSizeBytes));
144   }
145 
  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
147     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
148 
149     if (useGenericFn)
150       return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
151 
152     return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
153   }
154 
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
156                                           Location loc, Value sizeBytes,
157                                           Operation *op) const override {
158     // Heap allocations.
159     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
160     MemRefType memRefType = allocOp.getType();
161     int64_t alignment = getAllocationAlignment(allocOp);
162     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
163 
164     // aligned_alloc requires size to be a multiple of alignment; we will pad
165     // the size to the next multiple if necessary.
166     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
167       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
168 
169     Type elementPtrType = this->getElementPtrType(memRefType);
170     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
171     auto results =
172         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
173                        getVoidPtrType());
174     Value allocatedPtr =
175         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
176 
177     return std::make_tuple(allocatedPtr, allocatedPtr);
178   }
179 
180   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
181   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
182 
183   /// Default layout to use in absence of the corresponding analysis.
184   DataLayout defaultLayout;
185 };
186 
// Out-of-line definition, required until C++17.
188 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
189 
190 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
192       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
193                                 converter) {}
194 
  /// Allocates the underlying buffer using an LLVM `alloca`. For stack
  /// allocations no post-allocation alignment fixup is needed, so the same
  /// value is returned for both the allocated and the aligned pointer.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
199                                           Location loc, Value sizeBytes,
200                                           Operation *op) const override {
201 
    // With alloca, one gets a pointer to the element type right away; this is
    // a stack allocation.
204     auto allocaOp = cast<memref::AllocaOp>(op);
205     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
206 
207     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
208         loc, elementPtrType, sizeBytes, allocaOp.getAlignment().value_or(0));
209 
210     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
211   }
212 };
213 
214 struct AllocaScopeOpLowering
215     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
216   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
217 
218   LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
220                   ConversionPatternRewriter &rewriter) const override {
221     OpBuilder::InsertionGuard guard(rewriter);
222     Location loc = allocaScopeOp.getLoc();
223 
224     // Split the current block before the AllocaScopeOp to create the inlining
225     // point.
226     auto *currentBlock = rewriter.getInsertionBlock();
227     auto *remainingOpsBlock =
228         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
229     Block *continueBlock;
230     if (allocaScopeOp.getNumResults() == 0) {
231       continueBlock = remainingOpsBlock;
232     } else {
233       continueBlock = rewriter.createBlock(
234           remainingOpsBlock, allocaScopeOp.getResultTypes(),
235           SmallVector<Location>(allocaScopeOp->getNumResults(),
236                                 allocaScopeOp.getLoc()));
237       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
238     }
239 
240     // Inline body region.
241     Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
242     Block *afterBody = &allocaScopeOp.getBodyRegion().back();
243     rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
244 
245     // Save stack and then branch into the body of the region.
246     rewriter.setInsertionPointToEnd(currentBlock);
247     auto stackSaveOp =
248         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
249     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
250 
251     // Replace the alloca_scope return with a branch that jumps out of the body.
252     // Stack restore before leaving the body region.
253     rewriter.setInsertionPointToEnd(afterBody);
254     auto returnOp =
255         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
256     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
257         returnOp, returnOp.getResults(), continueBlock);
258 
    // Insert stack restore before jumping out of the body of the region.
260     rewriter.setInsertionPoint(branchOp);
261     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
262 
    // Replace the op with values returned from the body region.
264     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
265 
266     return success();
267   }
268 };
269 
270 struct AssumeAlignmentOpLowering
271     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
272   using ConvertOpToLLVMPattern<
273       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
274 
275   LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
277                   ConversionPatternRewriter &rewriter) const override {
278     Value memref = adaptor.getMemref();
279     unsigned alignment = op.getAlignment();
280     auto loc = op.getLoc();
281 
282     MemRefDescriptor memRefDescriptor(memref);
283     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
284 
285     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
286     // the asserted memref.alignedPtr isn't used anywhere else, as the real
287     // users like load/store/views always re-extract memref.alignedPtr as they
288     // get lowered.
289     //
290     // This relies on LLVM's CSE optimization (potentially after SROA), since
291     // after CSE all memref.alignedPtr instances get de-duplicated into the same
292     // pointer SSA value.
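    // For example, with an alignment of 16 the mask is 15, so the assume holds
    // exactly when the four low bits of the aligned pointer are zero.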
293     auto intPtrType =
294         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
295     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
296     Value mask =
297         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
298     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
299     rewriter.create<LLVM::AssumeOp>(
300         loc, rewriter.create<LLVM::ICmpOp>(
301                  loc, LLVM::ICmpPredicate::eq,
302                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
303 
304     rewriter.eraseOp(op);
305     return success();
306   }
307 };
308 
309 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
310 // The memref descriptor being an SSA value, there is no need to clean it up
311 // in any way.
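// For example, `memref.dealloc %m : memref<8xf32>` becomes a bitcast of the
// descriptor's allocated pointer to the void pointer type followed by a call
// to `free` (or to the generic free function when generic allocation functions
// are requested).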
312 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
313   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
314 
  explicit DeallocOpLowering(LLVMTypeConverter &converter)
316       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
317 
  LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
319     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
320 
321     if (useGenericFn)
322       return LLVM::lookupOrCreateGenericFreeFn(module);
323 
324     return LLVM::lookupOrCreateFreeFn(module);
325   }
326 
327   LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
329                   ConversionPatternRewriter &rewriter) const override {
330     // Insert the `free` declaration if it is not already present.
331     auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
332     MemRefDescriptor memref(adaptor.getMemref());
333     Value casted = rewriter.create<LLVM::BitcastOp>(
334         op.getLoc(), getVoidPtrType(),
335         memref.allocatedPtr(rewriter, op.getLoc()));
336     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
337         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
338     return success();
339   }
340 };
341 
342 // A `dim` is converted to a constant for static sizes and to an access to the
343 // size stored in the memref descriptor for dynamic sizes.
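// For example, for `%m : memref<4x?xf32>`, querying dimension 0 lowers to the
// constant 4, while querying dimension 1 reads the dynamic size from the
// memref descriptor.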
344 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
345   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
346 
347   LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
349                   ConversionPatternRewriter &rewriter) const override {
350     Type operandType = dimOp.getSource().getType();
351     if (operandType.isa<UnrankedMemRefType>()) {
352       rewriter.replaceOp(
353           dimOp, {extractSizeOfUnrankedMemRef(
354                      operandType, dimOp, adaptor.getOperands(), rewriter)});
355 
356       return success();
357     }
358     if (operandType.isa<MemRefType>()) {
359       rewriter.replaceOp(
360           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
361                                             adaptor.getOperands(), rewriter)});
362       return success();
363     }
364     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
365   }
366 
367 private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
369                                     OpAdaptor adaptor,
370                                     ConversionPatternRewriter &rewriter) const {
371     Location loc = dimOp.getLoc();
372 
373     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
374     auto scalarMemRefType =
375         MemRefType::get({}, unrankedMemRefType.getElementType());
376     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
377 
378     // Extract pointer to the underlying ranked descriptor and bitcast it to a
379     // memref<element_type> descriptor pointer to minimize the number of GEP
380     // operations.
381     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
382     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
383     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
384         loc,
385         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
386                                    addressSpace),
387         underlyingRankedDesc);
388 
389     // Get pointer to offset field of memref<element_type> descriptor.
390     Type indexPtrTy = LLVM::LLVMPointerType::get(
391         getTypeConverter()->getIndexType(), addressSpace);
392     Value two = rewriter.create<LLVM::ConstantOp>(
393         loc, typeConverter->convertType(rewriter.getI32Type()),
394         rewriter.getI32IntegerAttr(2));
395     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
396         loc, indexPtrTy, scalarMemRefDescPtr,
397         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
398 
    // The size value that we have to extract can be obtained using a GEP with
    // `dimOp.index() + 1` as the index argument.
401     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
402         loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
403     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
404                                                  ValueRange({idxPlusOne}));
405     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
406   }
407 
  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
409     if (Optional<int64_t> idx = dimOp.getConstantIndex())
410       return idx;
411 
412     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
413       return constantOp.getValue()
414           .cast<IntegerAttr>()
415           .getValue()
416           .getSExtValue();
417 
418     return llvm::None;
419   }
420 
  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
422                                   OpAdaptor adaptor,
423                                   ConversionPatternRewriter &rewriter) const {
424     Location loc = dimOp.getLoc();
425 
426     // Take advantage if index is constant.
427     MemRefType memRefType = operandType.cast<MemRefType>();
428     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
429       int64_t i = *index;
430       if (memRefType.isDynamicDim(i)) {
431         // extract dynamic size from the memref descriptor.
432         MemRefDescriptor descriptor(adaptor.getSource());
433         return descriptor.size(rewriter, loc, i);
434       }
435       // Use constant for static size.
436       int64_t dimSize = memRefType.getDimSize(i);
437       return createIndexConstant(rewriter, loc, dimSize);
438     }
439     Value index = adaptor.getIndex();
440     int64_t rank = memRefType.getRank();
441     MemRefDescriptor memrefDescriptor(adaptor.getSource());
442     return memrefDescriptor.size(rewriter, loc, index, rank);
443   }
444 };
445 
446 /// Common base for load and store operations on MemRefs. Restricts the match
447 /// to supported MemRef types. Provides functionality to emit code accessing a
448 /// specific element of the underlying data buffer.
449 template <typename Derived>
450 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
451   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
452   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
453   using Base = LoadStoreOpLowering<Derived>;
454 
  LogicalResult match(Derived op) const override {
456     MemRefType type = op.getMemRefType();
457     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
458   }
459 };
460 
/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
488 struct GenericAtomicRMWOpLowering
489     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
490   using Base::Base;
491 
492   LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
494                   ConversionPatternRewriter &rewriter) const override {
495     auto loc = atomicOp.getLoc();
496     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
497 
498     // Split the block into initial, loop, and ending parts.
499     auto *initBlock = rewriter.getInsertionBlock();
500     auto *loopBlock = rewriter.createBlock(
501         initBlock->getParent(), std::next(Region::iterator(initBlock)),
502         valueType, loc);
503     auto *endBlock = rewriter.createBlock(
504         loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
505 
    // Range of operations to be moved to `endBlock`.
507     auto opsToMoveStart = atomicOp->getIterator();
508     auto opsToMoveEnd = initBlock->back().getIterator();
509 
510     // Compute the loaded value and branch to the loop block.
511     rewriter.setInsertionPointToEnd(initBlock);
512     auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
513     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
514                                         adaptor.getIndices(), rewriter);
515     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
516     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
517 
518     // Prepare the body of the loop block.
519     rewriter.setInsertionPointToStart(loopBlock);
520 
521     // Clone the GenericAtomicRMWOp region and extract the result.
522     auto loopArgument = loopBlock->getArgument(0);
523     BlockAndValueMapping mapping;
524     mapping.map(atomicOp.getCurrentValue(), loopArgument);
525     Block &entryBlock = atomicOp.body().front();
526     for (auto &nestedOp : entryBlock.without_terminator()) {
527       Operation *clone = rewriter.clone(nestedOp, mapping);
528       mapping.map(nestedOp.getResults(), clone->getResults());
529     }
530     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
531 
532     // Prepare the epilog of the loop block.
533     // Append the cmpxchg op to the end of the loop block.
534     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
535     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
536     auto boolType = IntegerType::get(rewriter.getContext(), 1);
537     auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
538                                                      {valueType, boolType});
539     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
540         loc, pairType, dataPtr, loopArgument, result, successOrdering,
541         failureOrdering);
542     // Extract the %new_loaded and %ok values from the pair.
543     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
544         loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
545     Value ok = rewriter.create<LLVM::ExtractValueOp>(
546         loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
547 
548     // Conditionally branch to the end or back to the loop depending on %ok.
549     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
550                                     loopBlock, newLoaded);
551 
552     rewriter.setInsertionPointToEnd(endBlock);
553     moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
554                  std::next(opsToMoveEnd), rewriter);
555 
556     // The 'result' of the atomic_rmw op is the newly loaded value.
557     rewriter.replaceOp(atomicOp, {newLoaded});
558 
559     return success();
560   }
561 
562 private:
563   // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
565                     Block::iterator start, Block::iterator end,
566                     ConversionPatternRewriter &rewriter) const {
567     BlockAndValueMapping mapping;
568     mapping.map(oldResult, newResult);
569     SmallVector<Operation *, 2> opsToErase;
570     for (auto it = start; it != end; ++it) {
571       rewriter.clone(*it, mapping);
572       opsToErase.push_back(&*it);
573     }
574     for (auto *it : opsToErase)
575       rewriter.eraseOp(it);
576   }
577 };
578 
579 /// Returns the LLVM type of the global variable given the memref type `type`.
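/// For example, memref<4x8xf32> is converted to !llvm.array<4 x array<8 x f32>>.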
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
581                                           LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref is a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global ops with an initial value,
  // we do not intend to flatten the ElementsAttribute when lowering to the
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
587   Type elementType = typeConverter.convertType(type.getElementType());
588   Type arrayTy = elementType;
589   // Shape has the outermost dim at index 0, so need to walk it backwards
590   for (int64_t dim : llvm::reverse(type.getShape()))
591     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
592   return arrayTy;
593 }
594 
/// GlobalMemrefOp is lowered to an LLVM global variable.
596 struct GlobalMemrefOpLowering
597     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
598   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
599 
600   LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
602                   ConversionPatternRewriter &rewriter) const override {
603     MemRefType type = global.getType();
604     if (!isConvertibleAndHasIdentityMaps(type))
605       return failure();
606 
607     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
608 
609     LLVM::Linkage linkage =
610         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
611 
612     Attribute initialValue = nullptr;
613     if (!global.isExternal() && !global.isUninitialized()) {
614       auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
615       initialValue = elementsAttr;
616 
617       // For scalar memrefs, the global variable created is of the element type,
618       // so unpack the elements attribute to extract the value.
619       if (type.getRank() == 0)
620         initialValue = elementsAttr.getSplatValue<Attribute>();
621     }
622 
623     uint64_t alignment = global.getAlignment().value_or(0);
624 
625     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
626         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
627         initialValue, alignment, type.getMemorySpaceAsInt());
628     if (!global.isExternal() && global.isUninitialized()) {
629       Block *blk = new Block();
630       newGlobal.getInitializerRegion().push_back(blk);
631       rewriter.setInsertionPointToStart(blk);
632       Value undef[] = {
633           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
634       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
635     }
636     return success();
637   }
638 };
639 
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
643 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
645       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
646                                 converter) {}
647 
  /// Buffer "allocation" for a memref.get_global op amounts to taking the
  /// address of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
651                                           Location loc, Value sizeBytes,
652                                           Operation *op) const override {
653     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
654     MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
655     unsigned memSpace = type.getMemorySpaceAsInt();
656 
657     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
658     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
659         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
660         getGlobalOp.getName());
661 
    // Get the address of the first element in the array by creating a GEP with
    // the address of the global variable as the base and (rank + 1) zero
    // indices.
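    // For a rank-2 global this is roughly `llvm.getelementptr %addr[0, 0, 0]`,
    // which yields a pointer to the first element.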
664     Type elementType = typeConverter->convertType(type.getElementType());
665     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
666 
667     SmallVector<Value> operands;
668     operands.insert(operands.end(), type.getRank() + 1,
669                     createIndexConstant(rewriter, loc, 0));
670     auto gep =
671         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
672 
    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
676     auto intPtrType = getIntPtrType(memSpace);
677     Value deadBeefConst =
678         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
679     auto deadBeefPtr =
680         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
681 
    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
684     return std::make_tuple(deadBeefPtr, gep);
685   }
686 };
687 
688 // Load operation is lowered to obtaining a pointer to the indexed element
689 // and loading it.
690 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
691   using Base::Base;
692 
693   LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
695                   ConversionPatternRewriter &rewriter) const override {
696     auto type = loadOp.getMemRefType();
697 
698     Value dataPtr =
699         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
700                              adaptor.getIndices(), rewriter);
701     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
702     return success();
703   }
704 };
705 
706 // Store operation is lowered to obtaining a pointer to the indexed element,
707 // and storing the given value to it.
708 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
709   using Base::Base;
710 
711   LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
713                   ConversionPatternRewriter &rewriter) const override {
714     auto type = op.getMemRefType();
715 
716     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
717                                          adaptor.getIndices(), rewriter);
718     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
719     return success();
720   }
721 };
722 
723 // The prefetch operation is lowered in a way similar to the load operation
724 // except that the llvm.prefetch operation is used for replacement.
725 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
726   using Base::Base;
727 
728   LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
730                   ConversionPatternRewriter &rewriter) const override {
731     auto type = prefetchOp.getMemRefType();
732     auto loc = prefetchOp.getLoc();
733 
734     Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
735                                          adaptor.getIndices(), rewriter);
736 
737     // Replace with llvm.prefetch.
738     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
739     auto isWrite = rewriter.create<LLVM::ConstantOp>(
740         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
741     auto localityHint = rewriter.create<LLVM::ConstantOp>(
742         loc, llvmI32Type,
743         rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
744     auto isData = rewriter.create<LLVM::ConstantOp>(
745         loc, llvmI32Type,
746         rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
747 
748     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
749                                                 localityHint, isData);
750     return success();
751   }
752 };
753 
754 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
755   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
756 
757   LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
759                   ConversionPatternRewriter &rewriter) const override {
760     Location loc = op.getLoc();
761     Type operandType = op.getMemref().getType();
762     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
763       UnrankedMemRefDescriptor desc(adaptor.getMemref());
764       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
765       return success();
766     }
767     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
768       rewriter.replaceOp(
769           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
770       return success();
771     }
772     return failure();
773   }
774 };
775 
776 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
777   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
778 
  LogicalResult match(memref::CastOp memRefCastOp) const override {
780     Type srcType = memRefCastOp.getOperand().getType();
781     Type dstType = memRefCastOp.getType();
782 
    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed we can revisit.
788     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
789       return success(typeConverter->convertType(srcType) ==
790                      typeConverter->convertType(dstType));
791 
792     // At least one of the operands is unranked type
793     assert(srcType.isa<UnrankedMemRefType>() ||
794            dstType.isa<UnrankedMemRefType>());
795 
796     // Unranked to unranked cast is disallowed
797     return !(srcType.isa<UnrankedMemRefType>() &&
798              dstType.isa<UnrankedMemRefType>())
799                ? success()
800                : failure();
801   }
802 
  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
804                ConversionPatternRewriter &rewriter) const override {
805     auto srcType = memRefCastOp.getOperand().getType();
806     auto dstType = memRefCastOp.getType();
807     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
808     auto loc = memRefCastOp.getLoc();
809 
810     // For ranked/ranked case, just keep the original descriptor.
811     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
812       return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
813 
814     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
815       // Casting ranked to unranked memref type
816       // Set the rank in the destination from the memref type
817       // Allocate space on the stack and copy the src memref descriptor
818       // Set the ptr in the destination to the stack space
819       auto srcMemRefType = srcType.cast<MemRefType>();
820       int64_t rank = srcMemRefType.getRank();
821       // ptr = AllocaOp sizeof(MemRefDescriptor)
822       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
823           loc, adaptor.getSource(), rewriter);
824       // voidptr = BitCastOp srcType* to void*
825       auto voidPtr =
826           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
827               .getResult();
828       // rank = ConstantOp srcRank
829       auto rankVal = rewriter.create<LLVM::ConstantOp>(
830           loc, getIndexType(), rewriter.getIndexAttr(rank));
831       // undef = UndefOp
832       UnrankedMemRefDescriptor memRefDesc =
833           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
834       // d1 = InsertValueOp undef, rank, 0
835       memRefDesc.setRank(rewriter, loc, rankVal);
836       // d2 = InsertValueOp d1, voidptr, 1
837       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
838       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
839 
840     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
841       // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, the behavior is undefined.
844       UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
845       // ptr = ExtractValueOp src, 1
846       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
847       // castPtr = BitCastOp i8* to structTy*
848       auto castPtr =
849           rewriter
850               .create<LLVM::BitcastOp>(
851                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
852               .getResult();
853       // struct = LoadOp castPtr
854       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
855       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
856     } else {
857       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
858     }
859   }
860 };
861 
862 /// Pattern to lower a `memref.copy` to llvm.
863 ///
864 /// For memrefs with identity layouts, the copy is lowered to the llvm
865 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
866 /// to the generic `MemrefCopyFn`.
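/// For example, a copy between two memref<4x4xf32> values with the default
/// (identity) layout lowers to a single llvm.intr.memcpy, whereas copying to
/// or from a strided subview falls back to a call to the generic copy
/// function.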
867 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
868   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
869 
870   LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
872                           ConversionPatternRewriter &rewriter) const {
873     auto loc = op.getLoc();
874     auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
875 
876     MemRefDescriptor srcDesc(adaptor.getSource());
877 
878     // Compute number of elements.
879     Value numElements = rewriter.create<LLVM::ConstantOp>(
880         loc, getIndexType(), rewriter.getIndexAttr(1));
881     for (int pos = 0; pos < srcType.getRank(); ++pos) {
882       auto size = srcDesc.size(rewriter, loc, pos);
883       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
884     }
885 
886     // Get element size.
887     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
888     // Compute total.
889     Value totalSize =
890         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
891 
892     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
893     Value srcOffset = srcDesc.offset(rewriter, loc);
894     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
895                                                 srcBasePtr, srcOffset);
896     MemRefDescriptor targetDesc(adaptor.getTarget());
897     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
898     Value targetOffset = targetDesc.offset(rewriter, loc);
899     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
900                                                    targetBasePtr, targetOffset);
901     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
902         loc, typeConverter->convertType(rewriter.getI1Type()),
903         rewriter.getBoolAttr(false));
904     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
905                                     isVolatile);
906     rewriter.eraseOp(op);
907 
908     return success();
909   }
910 
911   LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
913                              ConversionPatternRewriter &rewriter) const {
914     auto loc = op.getLoc();
915     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
916     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
917 
918     // First make sure we have an unranked memref descriptor representation.
919     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
920       auto rank = rewriter.create<LLVM::ConstantOp>(
921           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
922       auto *typeConverter = getTypeConverter();
923       auto ptr =
924           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
925       auto voidPtr =
926           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
927               .getResult();
928       auto unrankedType =
929           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
930       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
931                                             unrankedType,
932                                             ValueRange{rank, voidPtr});
933     };
934 
935     Value unrankedSource = srcType.hasRank()
936                                ? makeUnranked(adaptor.getSource(), srcType)
937                                : adaptor.getSource();
938     Value unrankedTarget = targetType.hasRank()
939                                ? makeUnranked(adaptor.getTarget(), targetType)
940                                : adaptor.getTarget();
941 
942     // Now promote the unranked descriptors to the stack.
943     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
944                                                  rewriter.getIndexAttr(1));
945     auto promote = [&](Value desc) {
946       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
947       auto allocated =
948           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
949       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
950       return allocated;
951     };
952 
953     auto sourcePtr = promote(unrankedSource);
954     auto targetPtr = promote(unrankedTarget);
955 
956     unsigned typeSize =
957         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
958     auto elemSize = rewriter.create<LLVM::ConstantOp>(
959         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
960     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
961         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
962     rewriter.create<LLVM::CallOp>(loc, copyFn,
963                                   ValueRange{elemSize, sourcePtr, targetPtr});
964     rewriter.eraseOp(op);
965 
966     return success();
967   }
968 
969   LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
971                   ConversionPatternRewriter &rewriter) const override {
972     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
973     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
974 
975     auto isContiguousMemrefType = [](BaseMemRefType type) {
976       auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
980       return memrefType &&
981              (memrefType.getLayout().isIdentity() ||
982               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
983                isStaticShapeAndContiguousRowMajor(memrefType)));
984     };
985 
986     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
987       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
988 
989     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
990   }
991 };
992 
/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref value. In the unranked case, the fields are extracted from
/// the underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
997                                      ConversionPatternRewriter &rewriter,
998                                      LLVMTypeConverter &typeConverter,
999                                      Value originalOperand,
1000                                      Value convertedOperand,
1001                                      Value *allocatedPtr, Value *alignedPtr,
1002                                      Value *offset = nullptr) {
1003   Type operandType = originalOperand.getType();
1004   if (operandType.isa<MemRefType>()) {
1005     MemRefDescriptor desc(convertedOperand);
1006     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1007     *alignedPtr = desc.alignedPtr(rewriter, loc);
1008     if (offset != nullptr)
1009       *offset = desc.offset(rewriter, loc);
1010     return;
1011   }
1012 
1013   unsigned memorySpace =
1014       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
1015   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
1016   Type llvmElementType = typeConverter.convertType(elementType);
1017   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
1018       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
1019 
1020   // Extract pointer to the underlying ranked memref descriptor and cast it to
1021   // ElemType**.
1022   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1023   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1024 
1025   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1026       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1027   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1028       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1029   if (offset != nullptr) {
1030     *offset = UnrankedMemRefDescriptor::offset(
1031         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1032   }
1033 }
1034 
1035 struct MemRefReinterpretCastOpLowering
1036     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1037   using ConvertOpToLLVMPattern<
1038       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1039 
1040   LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1042                   ConversionPatternRewriter &rewriter) const override {
1043     Type srcType = castOp.getSource().getType();
1044 
1045     Value descriptor;
1046     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1047                                                adaptor, &descriptor)))
1048       return failure();
1049     rewriter.replaceOp(castOp, {descriptor});
1050     return success();
1051   }
1052 
1053 private:
  LogicalResult convertSourceMemRefToDescriptor(
1055       ConversionPatternRewriter &rewriter, Type srcType,
1056       memref::ReinterpretCastOp castOp,
1057       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1058     MemRefType targetMemRefType =
1059         castOp.getResult().getType().cast<MemRefType>();
1060     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1061                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1062     if (!llvmTargetDescriptorTy)
1063       return failure();
1064 
1065     // Create descriptor.
1066     Location loc = castOp.getLoc();
1067     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1068 
1069     // Set allocated and aligned pointers.
1070     Value allocatedPtr, alignedPtr;
1071     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1072                              castOp.getSource(), adaptor.getSource(),
1073                              &allocatedPtr, &alignedPtr);
1074     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1075     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1076 
1077     // Set offset.
1078     if (castOp.isDynamicOffset(0))
1079       desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1080     else
1081       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1082 
1083     // Set sizes and strides.
1084     unsigned dynSizeId = 0;
1085     unsigned dynStrideId = 0;
1086     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1087       if (castOp.isDynamicSize(i))
1088         desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1089       else
1090         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1091 
1092       if (castOp.isDynamicStride(i))
1093         desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1094       else
1095         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1096     }
1097     *descriptor = desc;
1098     return success();
1099   }
1100 };
1101 
1102 struct MemRefReshapeOpLowering
1103     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1104   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1105 
1106   LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1108                   ConversionPatternRewriter &rewriter) const override {
1109     Type srcType = reshapeOp.getSource().getType();
1110 
1111     Value descriptor;
1112     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1113                                                adaptor, &descriptor)))
1114       return failure();
1115     rewriter.replaceOp(reshapeOp, {descriptor});
1116     return success();
1117   }
1118 
1119 private:
1120   LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1122                                   Type srcType, memref::ReshapeOp reshapeOp,
1123                                   memref::ReshapeOp::Adaptor adaptor,
1124                                   Value *descriptor) const {
1125     auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
1126     if (shapeMemRefType.hasStaticShape()) {
1127       MemRefType targetMemRefType =
1128           reshapeOp.getResult().getType().cast<MemRefType>();
1129       auto llvmTargetDescriptorTy =
1130           typeConverter->convertType(targetMemRefType)
1131               .dyn_cast_or_null<LLVM::LLVMStructType>();
1132       if (!llvmTargetDescriptorTy)
1133         return failure();
1134 
1135       // Create descriptor.
1136       Location loc = reshapeOp.getLoc();
1137       auto desc =
1138           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1139 
1140       // Set allocated and aligned pointers.
1141       Value allocatedPtr, alignedPtr;
1142       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1143                                reshapeOp.getSource(), adaptor.getSource(),
1144                                &allocatedPtr, &alignedPtr);
1145       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1146       desc.setAlignedPtr(rewriter, loc, alignedPtr);
1147 
1148       // Extract the offset and strides from the type.
1149       int64_t offset;
1150       SmallVector<int64_t> strides;
1151       if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1152         return rewriter.notifyMatchFailure(
1153             reshapeOp, "failed to get stride and offset exprs");
1154 
1155       if (!isStaticStrideOrOffset(offset))
1156         return rewriter.notifyMatchFailure(reshapeOp,
1157                                            "dynamic offset is unsupported");
1158 
1159       desc.setConstantOffset(rewriter, loc, offset);
1160 
1161       assert(targetMemRefType.getLayout().isIdentity() &&
1162              "Identity layout map is a precondition of a valid reshape op");
1163 
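      // With an identity layout, the stride of a dimension is the product of
      // the sizes of all inner dimensions; e.g. for a 2x3x4 memref the strides
      // are [12, 4, 1]. The loop below builds them back to front, falling back
      // to runtime products when a size is dynamic.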
1164       Value stride = nullptr;
1165       int64_t targetRank = targetMemRefType.getRank();
1166       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1167         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1168           // If the stride for this dimension is static, materialize it as an
1169           // index constant.
1170           stride = createIndexConstant(rewriter, loc, strides[i]);
1171         } else if (!stride) {
1172           // `stride` is null only in the first iteration of the loop. Since the
1173           // target memref has an identity layout, the innermost stride can safely
1174           // be set to 1.
1175           stride = createIndexConstant(rewriter, loc, 1);
1176         }
1177 
1178         Value dimSize;
1179         int64_t size = targetMemRefType.getDimSize(i);
1180         // If the size of this dimension is static, materialize it as an index
1181         // constant; otherwise load it at runtime from the shape operand.
1182         if (!ShapedType::isDynamic(size)) {
1183           dimSize = createIndexConstant(rewriter, loc, size);
1184         } else {
1185           Value shapeOp = reshapeOp.getShape();
1186           Value index = createIndexConstant(rewriter, loc, i);
1187           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1188           Type indexType = getIndexType();
1189           if (dimSize.getType() != indexType)
1190             dimSize = typeConverter->materializeTargetConversion(
1191                 rewriter, loc, indexType, dimSize);
1192           assert(dimSize && "Invalid memref element type");
1193         }
1194 
1195         desc.setSize(rewriter, loc, i, dimSize);
1196         desc.setStride(rewriter, loc, i, stride);
1197 
1198         // Prepare the stride value for the next dimension.
1199         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1200       }
1201 
1202       *descriptor = desc;
1203       return success();
1204     }
1205 
1206     // The shape is a rank-1 memref with unknown length.
1207     Location loc = reshapeOp.getLoc();
1208     MemRefDescriptor shapeDesc(adaptor.getShape());
1209     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1210 
1211     // Extract address space and element type.
1212     auto targetType =
1213         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1214     unsigned addressSpace = targetType.getMemorySpaceAsInt();
1215     Type elementType = targetType.getElementType();
1216 
1217     // Create the unranked memref descriptor that holds the ranked one. The
1218     // inner descriptor is allocated on the stack.
1219     auto targetDesc = UnrankedMemRefDescriptor::undef(
1220         rewriter, loc, typeConverter->convertType(targetType));
1221     targetDesc.setRank(rewriter, loc, resultRank);
1222     SmallVector<Value, 4> sizes;
1223     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1224                                            targetDesc, sizes);
1225     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1226         loc, getVoidPtrType(), sizes.front(), llvm::None);
1227     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1228 
1229     // Extract pointers and offset from the source memref.
1230     Value allocatedPtr, alignedPtr, offset;
1231     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1232                              reshapeOp.getSource(), adaptor.getSource(),
1233                              &allocatedPtr, &alignedPtr, &offset);
1234 
1235     // Set pointers and offset.
1236     Type llvmElementType = typeConverter->convertType(elementType);
1237     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
1238         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
1239     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1240                                               elementPtrPtrType, allocatedPtr);
1241     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1242                                             underlyingDescPtr,
1243                                             elementPtrPtrType, alignedPtr);
1244     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1245                                         underlyingDescPtr, elementPtrPtrType,
1246                                         offset);
1247 
1248     // Use the offset pointer as base for further addressing. Copy over the new
1249     // shape and compute strides, using a loop that runs from rank - 1 down to 0.
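    // The generated control flow is equivalent to this pseudo-code (a sketch,
    // not literal IR):
    //   stride = 1;
    //   for (i = rank - 1; i >= 0; --i) {
    //     sizes[i] = shape[i];
    //     strides[i] = stride;
    //     stride *= sizes[i];
    //   }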
1250     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1251         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1252         elementPtrPtrType);
1253     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1254         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1255     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1256     Value oneIndex = createIndexConstant(rewriter, loc, 1);
1257     Value resultRankMinusOne =
1258         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1259 
1260     Block *initBlock = rewriter.getInsertionBlock();
1261     Type indexType = getTypeConverter()->getIndexType();
1262     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1263 
1264     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1265                                             {indexType, indexType}, {loc, loc});
1266 
1267     // Move the remaining initBlock ops to condBlock.
1268     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1269     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1270 
1271     rewriter.setInsertionPointToEnd(initBlock);
1272     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1273                                 condBlock);
1274     rewriter.setInsertionPointToStart(condBlock);
1275     Value indexArg = condBlock->getArgument(0);
1276     Value strideArg = condBlock->getArgument(1);
1277 
1278     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1279     Value pred = rewriter.create<LLVM::ICmpOp>(
1280         loc, IntegerType::get(rewriter.getContext(), 1),
1281         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1282 
1283     Block *bodyBlock =
1284         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1285     rewriter.setInsertionPointToStart(bodyBlock);
1286 
1287     // Copy size from shape to descriptor.
1288     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
1289     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1290         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
1291     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
1292     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1293                                       targetSizesBase, indexArg, size);
1294 
1295     // Write stride value and compute next one.
1296     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1297                                         targetStridesBase, indexArg, strideArg);
1298     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1299 
1300     // Decrement loop counter and branch back.
1301     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1302     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1303                                 condBlock);
1304 
1305     Block *remainder =
1306         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1307 
1308     // Hook up the cond exit to the remainder.
1309     rewriter.setInsertionPointToEnd(condBlock);
1310     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1311                                     llvm::None);
1312 
1313     // Reset position to beginning of new remainder block.
1314     rewriter.setInsertionPointToStart(remainder);
1315 
1316     *descriptor = targetDesc;
1317     return success();
1318   }
1319 };
1320 
1321 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1322 /// `Value`s.
1323 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1324                                       Type &llvmIndexType,
1325                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1326   return llvm::to_vector<4>(
1327       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1328         if (auto attr = value.dyn_cast<Attribute>())
1329           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1330         return value.get<Value>();
1331       }));
1332 }
1333 
1334 /// Compute a map that for a given dimension of the expanded type gives the
1335 /// dimension in the collapsed type it maps to. Essentially it is the inverse of
1336 /// the `reassociation` maps.
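/// For example (illustrative), with reassociation indices [[0, 1], [2]] the
/// expanded dims 0 and 1 both map to collapsed dim 0 and expanded dim 2 maps
/// to collapsed dim 1, i.e. the result is {0: 0, 1: 0, 2: 1}.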
1337 static DenseMap<int64_t, int64_t>
1338 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1339   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1340   for (auto &en : enumerate(reassociation)) {
1341     for (auto dim : en.value())
1342       expandedDimToCollapsedDim[dim] = en.index();
1343   }
1344   return expandedDimToCollapsedDim;
1345 }
1346 
1347 static OpFoldResult
1348 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1349                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1350                          MemRefDescriptor &inDesc,
1351                          ArrayRef<int64_t> inStaticShape,
1352                          ArrayRef<ReassociationIndices> reassocation,
1353                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1354   int64_t outDimSize = outStaticShape[outDimIndex];
1355   if (!ShapedType::isDynamic(outDimSize))
1356     return b.getIndexAttr(outDimSize);
1357 
1358   // Compute the product of all output dim sizes in the group except the
1359   // current dim.
1360   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1361   int64_t otherDimSizesMul = 1;
1362   for (auto otherDimIndex : reassocation[inDimIndex]) {
1363     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1364       continue;
1365     int64_t otherDimSize = outStaticShape[otherDimIndex];
1366     assert(!ShapedType::isDynamic(otherDimSize) &&
1367            "single dimension cannot be expanded into multiple dynamic "
1368            "dimensions");
1369     otherDimSizesMul *= otherDimSize;
1370   }
1371 
1372   // outDimSize = inDimSize / otherDimSizesMul
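  // E.g. (hypothetical shapes) expanding memref<?xf32> into memref<?x4xf32>
  // with reassociation [[0, 1]]: the dynamic output dim 0 is size(in dim 0) / 4.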
1373   int64_t inDimSize = inStaticShape[inDimIndex];
1374   Value inDimSizeDynamic =
1375       ShapedType::isDynamic(inDimSize)
1376           ? inDesc.size(b, loc, inDimIndex)
1377           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1378                                        b.getIndexAttr(inDimSize));
1379   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1380       loc, inDimSizeDynamic,
1381       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1382                                  b.getIndexAttr(otherDimSizesMul)));
1383   return outDimSizeDynamic;
1384 }
1385 
1386 static OpFoldResult getCollapsedOutputDimSize(
1387     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1388     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1389     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1390   if (!ShapedType::isDynamic(outDimSize))
1391     return b.getIndexAttr(outDimSize);
1392 
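  // E.g. (hypothetical shapes) collapsing memref<4x?xf32> into memref<?xf32>
  // with reassociation [[0, 1]]: the dynamic output dim is 4 * size(in dim 1).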
1393   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1394   Value outDimSizeDynamic = c1;
1395   for (auto inDimIndex : reassocation[outDimIndex]) {
1396     int64_t inDimSize = inStaticShape[inDimIndex];
1397     Value inDimSizeDynamic =
1398         ShapedType::isDynamic(inDimSize)
1399             ? inDesc.size(b, loc, inDimIndex)
1400             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1401                                          b.getIndexAttr(inDimSize));
1402     outDimSizeDynamic =
1403         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1404   }
1405   return outDimSizeDynamic;
1406 }
1407 
1408 static SmallVector<OpFoldResult, 4>
1409 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1410                         ArrayRef<ReassociationIndices> reassociation,
1411                         ArrayRef<int64_t> inStaticShape,
1412                         MemRefDescriptor &inDesc,
1413                         ArrayRef<int64_t> outStaticShape) {
1414   return llvm::to_vector<4>(llvm::map_range(
1415       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1416         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1417                                          outStaticShape[outDimIndex],
1418                                          inStaticShape, inDesc, reassociation);
1419       }));
1420 }
1421 
1422 static SmallVector<OpFoldResult, 4>
1423 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1424                        ArrayRef<ReassociationIndices> reassociation,
1425                        ArrayRef<int64_t> inStaticShape,
1426                        MemRefDescriptor &inDesc,
1427                        ArrayRef<int64_t> outStaticShape) {
1428   DenseMap<int64_t, int64_t> outDimToInDimMap =
1429       getExpandedDimToCollapsedDimMap(reassociation);
1430   return llvm::to_vector<4>(llvm::map_range(
1431       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1432         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1433                                         outStaticShape, inDesc, inStaticShape,
1434                                         reassociation, outDimToInDimMap);
1435       }));
1436 }
1437 
1438 static SmallVector<Value>
1439 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1440                       ArrayRef<ReassociationIndices> reassociation,
1441                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1442                       ArrayRef<int64_t> outStaticShape) {
1443   return outStaticShape.size() < inStaticShape.size()
1444              ? getAsValues(b, loc, llvmIndexType,
1445                            getCollapsedOutputShape(b, loc, llvmIndexType,
1446                                                    reassociation, inStaticShape,
1447                                                    inDesc, outStaticShape))
1448              : getAsValues(b, loc, llvmIndexType,
1449                            getExpandedOutputShape(b, loc, llvmIndexType,
1450                                                   reassociation, inStaticShape,
1451                                                   inDesc, outStaticShape));
1452 }
1453 
1454 static void fillInStridesForExpandedMemDescriptor(
1455     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1456     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1457   // See comments for computeExpandedLayoutMap for details on how the strides
1458   // are calculated.
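  // E.g. (hypothetical shapes) expanding a source with strides [s0, s1] via
  // reassociation [[0, 1], [2]] into a 2x3x? result yields dst strides
  // [3 * s0, s0, s1].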
1459   for (auto &en : llvm::enumerate(reassociation)) {
1460     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1461     for (auto dstIndex : llvm::reverse(en.value())) {
1462       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1463       Value size = dstDesc.size(b, loc, dstIndex);
1464       currentStrideToExpand =
1465           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1466     }
1467   }
1468 }
1469 
1470 static void fillInStridesForCollapsedMemDescriptor(
1471     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1472     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1473     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1474   // See comments for computeCollapsedLayoutMap for details on how the strides
1475   // are calculated.
1476   auto srcShape = srcType.getShape();
1477   for (auto &en : llvm::enumerate(reassociation)) {
1478     rewriter.setInsertionPoint(op);
1479     auto dstIndex = en.index();
1480     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1481     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1482       ref = ref.drop_back();
1483     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1484       dstDesc.setStride(rewriter, loc, dstIndex,
1485                         srcDesc.stride(rewriter, loc, ref.back()));
1486     } else {
1487       // Iterate over the source strides in reverse order. Skip over the
1488       // dimensions whose size is 1.
1489       // TODO: we should take the minimum stride in the reassociation group
1490       // instead of just the first where the dimension is not 1.
1491       //
1492       // +------------------------------------------------------+
1493       // | curEntry:                                            |
1494       // |   %srcStride = strides[srcIndex]                     |
1495       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1496       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1497       // +-------------------------+----------------------------+  |
1498       //                           |                               |
1499       //                           v                               |
1500       //            +-----------------------------+                |
1501       //            | nextEntry:                  |                |
1502       //            |   ...                       +---+            |
1503       //            +--------------+--------------+   |            |
1504       //                           |                  |            |
1505       //                           v                  |            |
1506       //            +-----------------------------+   |            |
1507       //            | nextEntry:                  |   |            |
1508       //            |   ...                       |   |            |
1509       //            +--------------+--------------+   |   +--------+
1510       //                           |                  |   |
1511       //                           v                  v   v
1512       //   +--------------------------------------------------+
1513       //   | continue(%newStride):                            |
1514       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1515       //   +--------------------------------------------------+
1516       OpBuilder::InsertionGuard guard(rewriter);
1517       Block *initBlock = rewriter.getInsertionBlock();
1518       Block *continueBlock =
1519           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1520       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1521       rewriter.setInsertionPointToStart(continueBlock);
1522       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1523 
1524       Block *curEntryBlock = initBlock;
1525       Block *nextEntryBlock;
1526       for (auto srcIndex : llvm::reverse(ref)) {
1527         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1528           continue;
1529         rewriter.setInsertionPointToEnd(curEntryBlock);
1530         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1531         if (srcIndex == ref.front()) {
1532           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1533           break;
1534         }
1535         Value one = rewriter.create<LLVM::ConstantOp>(
1536             loc, typeConverter->convertType(rewriter.getI64Type()),
1537             rewriter.getI32IntegerAttr(1));
1538         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1539             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1540             one);
1541         {
1542           OpBuilder::InsertionGuard guard(rewriter);
1543           nextEntryBlock = rewriter.createBlock(
1544               initBlock->getParent(), Region::iterator(continueBlock), {});
1545         }
1546         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1547                                         srcStride, nextEntryBlock, llvm::None);
1548         curEntryBlock = nextEntryBlock;
1549       }
1550     }
1551   }
1552 }
1553 
1554 static void fillInDynamicStridesForMemDescriptor(
1555     ConversionPatternRewriter &b, Location loc, Operation *op,
1556     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1557     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1558     ArrayRef<ReassociationIndices> reassociation) {
1559   if (srcType.getRank() > dstType.getRank())
1560     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1561                                            srcDesc, dstDesc, reassociation);
1562   else
1563     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1564                                           reassociation);
1565 }
1566 
1567 // ReshapeOp creates a new view descriptor of the proper rank.
1568 // Static sizes and strides are materialized as constants; dynamic ones are
1569 // recomputed at runtime from the source descriptor and the reassociation.
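// As an illustration only (shapes are hypothetical), this covers ops such as:
//   %0 = memref.collapse_shape %src [[0, 1], [2]]
//            : memref<2x3x4xf32> into memref<6x4xf32>
//   %1 = memref.expand_shape %dst [[0, 1], [2]]
//            : memref<6x4xf32> into memref<2x3x4xf32>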
1570 template <typename ReshapeOp>
1571 class ReassociatingReshapeOpConversion
1572     : public ConvertOpToLLVMPattern<ReshapeOp> {
1573 public:
1574   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1575   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1576 
1577   LogicalResult
1578   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1579                   ConversionPatternRewriter &rewriter) const override {
1580     MemRefType dstType = reshapeOp.getResultType();
1581     MemRefType srcType = reshapeOp.getSrcType();
1582 
1583     int64_t offset;
1584     SmallVector<int64_t, 4> strides;
1585     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1586       return rewriter.notifyMatchFailure(
1587           reshapeOp, "failed to get stride and offset exprs");
1588     }
1589 
1590     MemRefDescriptor srcDesc(adaptor.getSrc());
1591     Location loc = reshapeOp->getLoc();
1592     auto dstDesc = MemRefDescriptor::undef(
1593         rewriter, loc, this->typeConverter->convertType(dstType));
1594     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1595     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1596     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1597 
1598     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1599     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1600     Type llvmIndexType =
1601         this->typeConverter->convertType(rewriter.getIndexType());
1602     SmallVector<Value> dstShape = getDynamicOutputShape(
1603         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1604         srcStaticShape, srcDesc, dstStaticShape);
1605     for (auto &en : llvm::enumerate(dstShape))
1606       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1607 
1608     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1609       for (auto &en : llvm::enumerate(strides))
1610         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1611     } else if (srcType.getLayout().isIdentity() &&
1612                dstType.getLayout().isIdentity()) {
1613       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1614                                                    rewriter.getIndexAttr(1));
1615       Value stride = c1;
1616       for (auto dimIndex :
1617            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1618         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1619         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1620       }
1621     } else {
1622       // There could be mixed static/dynamic strides. For simplicity, we
1623       // recompute all strides if there is at least one dynamic stride.
1624       fillInDynamicStridesForMemDescriptor(
1625           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1626           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1627     }
1628     rewriter.replaceOp(reshapeOp, {dstDesc});
1629     return success();
1630   }
1631 };
1632 
1633 /// Conversion pattern that transforms a subview op into:
1634 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1635 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1636 ///      and stride.
1637 /// The subview op is replaced by the descriptor.
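/// As an illustration only (operands and result layout are hypothetical):
///   %1 = memref.subview %0[%off0, %off1] [%sz0, %sz1] [%st0, %st1]
///          : memref<8x16xf32> to memref<?x?xf32, #strided2D>
/// where #strided2D stands for the strided layout inferred for the result.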
1638 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1639   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1640 
1641   LogicalResult
1642   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1643                   ConversionPatternRewriter &rewriter) const override {
1644     auto loc = subViewOp.getLoc();
1645 
1646     auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
1647     auto sourceElementTy =
1648         typeConverter->convertType(sourceMemRefType.getElementType());
1649 
1650     auto viewMemRefType = subViewOp.getType();
1651     auto inferredType =
1652         memref::SubViewOp::inferResultType(
1653             subViewOp.getSourceType(),
1654             extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1655             extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1656             extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
1657             .cast<MemRefType>();
1658     auto targetElementTy =
1659         typeConverter->convertType(viewMemRefType.getElementType());
1660     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1661     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1662         !LLVM::isCompatibleType(sourceElementTy) ||
1663         !LLVM::isCompatibleType(targetElementTy) ||
1664         !LLVM::isCompatibleType(targetDescTy))
1665       return failure();
1666 
1667     // Extract the offset and strides from the type.
1668     int64_t offset;
1669     SmallVector<int64_t, 4> strides;
1670     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1671     if (failed(successStrides))
1672       return failure();
1673 
1674     // Create the descriptor.
1675     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1676       return failure();
1677     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1678     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1679 
1680     // Copy the buffer pointer from the old descriptor to the new one.
1681     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1682     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1683         loc,
1684         LLVM::LLVMPointerType::get(targetElementTy,
1685                                    viewMemRefType.getMemorySpaceAsInt()),
1686         extracted);
1687     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1688 
1689     // Copy the aligned pointer from the old descriptor to the new one.
1690     extracted = sourceMemRef.alignedPtr(rewriter, loc);
1691     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1692         loc,
1693         LLVM::LLVMPointerType::get(targetElementTy,
1694                                    viewMemRefType.getMemorySpaceAsInt()),
1695         extracted);
1696     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1697 
1698     size_t inferredShapeRank = inferredType.getRank();
1699     size_t resultShapeRank = viewMemRefType.getRank();
1700 
1701     // Extract strides needed to compute offset.
1702     SmallVector<Value, 4> strideValues;
1703     strideValues.reserve(inferredShapeRank);
1704     for (unsigned i = 0; i < inferredShapeRank; ++i)
1705       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1706 
1707     // Offset.
1708     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1709     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1710       targetMemRef.setConstantOffset(rewriter, loc, offset);
1711     } else {
1712       Value baseOffset = sourceMemRef.offset(rewriter, loc);
1713       // `inferredShapeRank` may be larger than the number of offset operands
1714       // because of trailing semantics. In this case, the offset is guaranteed
1715       // to be interpreted as 0 and we can just skip the extra dimensions.
1716       for (unsigned i = 0, e = std::min(inferredShapeRank,
1717                                         subViewOp.getMixedOffsets().size());
1718            i < e; ++i) {
1719         Value offset =
1720             // TODO: need OpFoldResult ODS adaptor to clean this up.
1721             subViewOp.isDynamicOffset(i)
1722                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1723                 : rewriter.create<LLVM::ConstantOp>(
1724                       loc, llvmIndexType,
1725                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1726         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1727         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1728       }
1729       targetMemRef.setOffset(rewriter, loc, baseOffset);
1730     }
1731 
1732     // Update sizes and strides.
1733     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1734     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1735     assert(mixedSizes.size() == mixedStrides.size() &&
1736            "expected sizes and strides of equal length");
1737     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1738     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1739          i >= 0 && j >= 0; --i) {
1740       if (unusedDims.test(i))
1741         continue;
1742 
1743       // `i` may exceed the number of entries in subViewOp.getMixedSizes because
1744       // of trailing semantics. In this case, the size is guaranteed to be the
1745       // full source dimension and the stride to be 1.
1746       Value size, stride;
1747       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1748         // If the static size is available, use it directly. This is similar to
1749         // the folding of dim(constant-op) but removes the need for dim to be
1750         // aware of LLVM constants and for this pass to be aware of std
1751         // constants.
1752         int64_t staticSize =
1753             subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
1754         if (staticSize != ShapedType::kDynamicSize) {
1755           size = rewriter.create<LLVM::ConstantOp>(
1756               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1757         } else {
1758           Value pos = rewriter.create<LLVM::ConstantOp>(
1759               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1760           Value dim =
1761               rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1762           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1763               loc, llvmIndexType, dim);
1764           size = cast.getResult(0);
1765         }
1766         stride = rewriter.create<LLVM::ConstantOp>(
1767             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1768       } else {
1769         // TODO: need OpFoldResult ODS adaptor to clean this up.
1770         size =
1771             subViewOp.isDynamicSize(i)
1772                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1773                 : rewriter.create<LLVM::ConstantOp>(
1774                       loc, llvmIndexType,
1775                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1776         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1777           stride = rewriter.create<LLVM::ConstantOp>(
1778               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1779         } else {
1780           stride =
1781               subViewOp.isDynamicStride(i)
1782                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1783                   : rewriter.create<LLVM::ConstantOp>(
1784                         loc, llvmIndexType,
1785                         rewriter.getI64IntegerAttr(
1786                             subViewOp.getStaticStride(i)));
1787           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1788         }
1789       }
1790       targetMemRef.setSize(rewriter, loc, j, size);
1791       targetMemRef.setStride(rewriter, loc, j, stride);
1792       j--;
1793     }
1794 
1795     rewriter.replaceOp(subViewOp, {targetMemRef});
1796     return success();
1797   }
1798 };
1799 
1800 /// Conversion pattern that transforms a transpose op into:
1801 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
1802 ///   2. A copy of the allocated and aligned pointers and of the offset from
1803 ///      the source descriptor to the new one.
1804 ///   3. Updates to the descriptor sizes and strides, which are permutations
1805 ///      of the original values.
1806 /// The transpose op is replaced by the descriptor.
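/// As an illustration only (types are hypothetical):
///   %1 = memref.transpose %0 (i, j) -> (j, i)
///          : memref<?x?xf32> to memref<?x?xf32, #transposed>
/// where #transposed stands for the permuted strided layout of the result.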
1807 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1808 public:
1809   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1810 
1811   LogicalResult
1812   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1813                   ConversionPatternRewriter &rewriter) const override {
1814     auto loc = transposeOp.getLoc();
1815     MemRefDescriptor viewMemRef(adaptor.getIn());
1816 
1817     // No permutation, early exit.
1818     if (transposeOp.getPermutation().isIdentity())
1819       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1820 
1821     auto targetMemRef = MemRefDescriptor::undef(
1822         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1823 
1824     // Copy the base and aligned pointers from the old descriptor to the new
1825     // one.
1826     targetMemRef.setAllocatedPtr(rewriter, loc,
1827                                  viewMemRef.allocatedPtr(rewriter, loc));
1828     targetMemRef.setAlignedPtr(rewriter, loc,
1829                                viewMemRef.alignedPtr(rewriter, loc));
1830 
1831     // Copy the offset pointer from the old descriptor to the new one.
1832     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1833 
1834     // Iterate over the dimensions and apply size/stride permutation.
1835     for (const auto &en :
1836          llvm::enumerate(transposeOp.getPermutation().getResults())) {
1837       int sourcePos = en.index();
1838       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1839       targetMemRef.setSize(rewriter, loc, targetPos,
1840                            viewMemRef.size(rewriter, loc, sourcePos));
1841       targetMemRef.setStride(rewriter, loc, targetPos,
1842                              viewMemRef.stride(rewriter, loc, sourcePos));
1843     }
1844 
1845     rewriter.replaceOp(transposeOp, {targetMemRef});
1846     return success();
1847   }
1848 };
1849 
1850 /// Conversion pattern that transforms an op into:
1851 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1852 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1853 ///      and stride.
1854 /// The view op is replaced by the descriptor.
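/// As an illustration only (types are hypothetical):
///   %1 = memref.view %0[%byte_shift][%size]
///          : memref<2048xi8> to memref<?x128xf32>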
1855 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1856   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1857 
1858   // Build and return the value for the idx^th shape dimension, either as a
1859   // constant or by selecting the matching dynamic size operand.
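  // E.g. (hypothetical) for shape [4, ?, ?, 8] and idx == 2 this returns
  // dynamicSizes[1], since exactly one dynamic dimension precedes index 2.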
1860   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1861                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1862                 unsigned idx) const {
1863     assert(idx < shape.size());
1864     if (!ShapedType::isDynamic(shape[idx]))
1865       return createIndexConstant(rewriter, loc, shape[idx]);
1866     // Count the number of dynamic dims in range [0, idx).
1867     unsigned nDynamic =
1868         llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1869     return dynamicSizes[nDynamic];
1870   }
1871 
1872   // Build and return the idx^th stride, either by returning the constant stride
1873   // or by computing the dynamic stride from the current `runningStride` and
1874   // `nextSize`. The caller should keep a running stride and update it with the
1875   // result returned by this function.
1876   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1877                   ArrayRef<int64_t> strides, Value nextSize,
1878                   Value runningStride, unsigned idx) const {
1879     assert(idx < strides.size());
1880     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1881       return createIndexConstant(rewriter, loc, strides[idx]);
1882     if (nextSize)
1883       return runningStride
1884                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1885                  : nextSize;
1886     assert(!runningStride);
1887     return createIndexConstant(rewriter, loc, 1);
1888   }
1889 
1890   LogicalResult
1891   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1892                   ConversionPatternRewriter &rewriter) const override {
1893     auto loc = viewOp.getLoc();
1894 
1895     auto viewMemRefType = viewOp.getType();
1896     auto targetElementTy =
1897         typeConverter->convertType(viewMemRefType.getElementType());
1898     auto targetDescTy = typeConverter->convertType(viewMemRefType);
1899     if (!targetDescTy || !targetElementTy ||
1900         !LLVM::isCompatibleType(targetElementTy) ||
1901         !LLVM::isCompatibleType(targetDescTy))
1902       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1903              failure();
1904 
1905     int64_t offset;
1906     SmallVector<int64_t, 4> strides;
1907     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1908     if (failed(successStrides))
1909       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1910     assert(offset == 0 && "expected offset to be 0");
1911 
1912     // Target memref must be contiguous in memory (innermost stride is 1), or
1913     // empty (special case when at least one of the memref dimensions is 0).
1914     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1915       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1916              failure();
1917 
1918     // Create the descriptor.
1919     MemRefDescriptor sourceMemRef(adaptor.getSource());
1920     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1921 
1922     // Field 1: Copy the allocated pointer, used for malloc/free.
1923     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1924     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
1925     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1926         loc,
1927         LLVM::LLVMPointerType::get(targetElementTy,
1928                                    srcMemRefType.getMemorySpaceAsInt()),
1929         allocatedPtr);
1930     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1931 
1932     // Field 2: Copy the actual aligned pointer to payload.
1933     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1934     alignedPtr = rewriter.create<LLVM::GEPOp>(
1935         loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
1936     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1937         loc,
1938         LLVM::LLVMPointerType::get(targetElementTy,
1939                                    srcMemRefType.getMemorySpaceAsInt()),
1940         alignedPtr);
1941     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1942 
1943     // Field 3: The offset in the resulting type must be 0. This is because of
1944     // the type change: an offset on srcType* may not be expressible as an
1945     // offset on dstType*.
1946     targetMemRef.setOffset(rewriter, loc,
1947                            createIndexConstant(rewriter, loc, offset));
1948 
1949     // Early exit for 0-D corner case.
1950     if (viewMemRefType.getRank() == 0)
1951       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1952 
1953     // Fields 4 and 5: Update sizes and strides.
1954     Value stride = nullptr, nextSize = nullptr;
1955     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1956       // Update size.
1957       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1958                            adaptor.getSizes(), i);
1959       targetMemRef.setSize(rewriter, loc, i, size);
1960       // Update stride.
1961       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1962       targetMemRef.setStride(rewriter, loc, i, stride);
1963       nextSize = size;
1964     }
1965 
1966     rewriter.replaceOp(viewOp, {targetMemRef});
1967     return success();
1968   }
1969 };
1970 
1971 //===----------------------------------------------------------------------===//
1972 // AtomicRMWOpLowering
1973 //===----------------------------------------------------------------------===//
1974 
1975 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1976 /// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
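/// As an illustration only (types are hypothetical), an op such as
///   %x = memref.atomic_rmw addf %val, %mem[%i] : (f32, memref<8xf32>) -> f32
/// maps to LLVM::AtomicBinOp::fadd; kinds without a direct llvm.atomicrmw
/// equivalent return llvm::None.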
1977 static Optional<LLVM::AtomicBinOp>
1978 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1979   switch (atomicOp.getKind()) {
1980   case arith::AtomicRMWKind::addf:
1981     return LLVM::AtomicBinOp::fadd;
1982   case arith::AtomicRMWKind::addi:
1983     return LLVM::AtomicBinOp::add;
1984   case arith::AtomicRMWKind::assign:
1985     return LLVM::AtomicBinOp::xchg;
1986   case arith::AtomicRMWKind::maxs:
1987     return LLVM::AtomicBinOp::max;
1988   case arith::AtomicRMWKind::maxu:
1989     return LLVM::AtomicBinOp::umax;
1990   case arith::AtomicRMWKind::mins:
1991     return LLVM::AtomicBinOp::min;
1992   case arith::AtomicRMWKind::minu:
1993     return LLVM::AtomicBinOp::umin;
1994   case arith::AtomicRMWKind::ori:
1995     return LLVM::AtomicBinOp::_or;
1996   case arith::AtomicRMWKind::andi:
1997     return LLVM::AtomicBinOp::_and;
1998   default:
1999     return llvm::None;
2000   }
2001   llvm_unreachable("Invalid AtomicRMWKind");
2002 }
2003 
2004 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
2005   using Base::Base;
2006 
2007   LogicalResult
2008   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
2009                   ConversionPatternRewriter &rewriter) const override {
2010     if (failed(match(atomicOp)))
2011       return failure();
2012     auto maybeKind = matchSimpleAtomicOp(atomicOp);
2013     if (!maybeKind)
2014       return failure();
2015     auto resultType = adaptor.getValue().getType();
2016     auto memRefType = atomicOp.getMemRefType();
2017     auto dataPtr =
2018         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2019                              adaptor.getIndices(), rewriter);
2020     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
2021         atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
2022         LLVM::AtomicOrdering::acq_rel);
2023     return success();
2024   }
2025 };
2026 
2027 } // namespace
2028 
2029 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
2030                                                   RewritePatternSet &patterns) {
2031   // clang-format off
2032   patterns.add<
2033       AllocaOpLowering,
2034       AllocaScopeOpLowering,
2035       AtomicRMWOpLowering,
2036       AssumeAlignmentOpLowering,
2037       DimOpLowering,
2038       GenericAtomicRMWOpLowering,
2039       GlobalMemrefOpLowering,
2040       GetGlobalMemrefOpLowering,
2041       LoadOpLowering,
2042       MemRefCastOpLowering,
2043       MemRefCopyOpLowering,
2044       MemRefReinterpretCastOpLowering,
2045       MemRefReshapeOpLowering,
2046       PrefetchOpLowering,
2047       RankOpLowering,
2048       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2049       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2050       StoreOpLowering,
2051       SubViewOpLowering,
2052       TransposeOpLowering,
2053       ViewOpLowering>(converter);
2054   // clang-format on
2055   auto allocLowering = converter.getOptions().allocLowering;
2056   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2057     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
2058   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2059     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
2060 }
2061 
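// Sketch of typical driver usage (the command-line name is assumed from the
// pass definition): `mlir-opt --convert-memref-to-llvm input.mlir` runs the
// pass below, which populates these patterns and applies a partial conversion
// to the LLVM dialect.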
2062 namespace {
2063 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2064   MemRefToLLVMPass() = default;
2065 
2066   void runOnOperation() override {
2067     Operation *op = getOperation();
2068     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2069     LowerToLLVMOptions options(&getContext(),
2070                                dataLayoutAnalysis.getAtOrAbove(op));
2071     options.allocLowering =
2072         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2073                          : LowerToLLVMOptions::AllocLowering::Malloc);
2074 
2075     options.useGenericFunctions = useGenericFunctions;
2076 
2077     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2078       options.overrideIndexBitwidth(indexBitwidth);
2079 
2080     LLVMTypeConverter typeConverter(&getContext(), options,
2081                                     &dataLayoutAnalysis);
2082     RewritePatternSet patterns(&getContext());
2083     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2084     LLVMConversionTarget target(getContext());
2085     target.addLegalOp<func::FuncOp>();
2086     if (failed(applyPartialConversion(op, target, std::move(patterns))))
2087       signalPassFailure();
2088   }
2089 };
2090 } // namespace
2091 
2092 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2093   return std::make_unique<MemRefToLLVMPass>();
2094 }
2095