//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());

    return LLVM::lookupOrCreateMallocFn(module, getIndexType());
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.getAlignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
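      // Sketch of the intent (assuming `createAligned`, defined in
      // AllocLikeConversion, rounds its input up to the next multiple of
      // `alignment`): alignedInt is the smallest value >= allocatedInt with
      // alignedInt % alignment == 0. The extra `alignment` bytes added to
      // sizeBytes above guarantee the aligned pointer still lies inside the
      // allocation.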
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the allocation size has to be a
    // power of two, we will bump to the next power of two if it already isn't.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());

    return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
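    // For instance, with alignment 16 and a computed sizeBytes of 40, the
    // padded allocation size becomes 48 (illustrative arithmetic only).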
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using an `llvm.alloca`. For stack
  /// allocations the allocated pointer already carries the requested
  /// alignment, so the allocated and aligned pointers returned are the same
  /// value.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes, allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericFreeFn(module);

    return LLVM::lookupOrCreateFreeFn(module);
  }

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
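//
// Illustrative sketch (not verbatim output): for
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
// the dynamic size is read from the converted descriptor, roughly
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// whereas a static dimension folds to an llvm.mlir.constant.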
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});
      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(
            typeConverter->convertType(scalarMemRefType), addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index, if any.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.getSource());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |----------|         |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
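    // Note: the cmpxchg result is a literal struct in which element 0 is the
    // value read from memory and element 1 is an i1 flag that is true when
    // the exchange succeeded (standard LLVM cmpxchg semantics).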
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor.
/// This reuses `AllocLikeOpLowering` for the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf,
                                            operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to be a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are same. We could potentially stash
    // a nullptr for the allocated pointer since we do not expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() &&
               dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource =
        srcType.hasRank() ? makeUnranked(adaptor.getSource(), srcType)
                          : adaptor.getSource();
    Value unrankedTarget =
        targetType.hasRank() ? makeUnranked(adaptor.getTarget(), targetType)
                             : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // The stride for this dimension is static; materialize it as a
          // constant.
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexConstant(rewriter, loc, 1);
        }

        Value dimSize;
        int64_t size = targetMemRefType.getDimSize(i);
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse of
/// the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the multiplication of all the out dim sizes except the
  // current dim.
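  // Illustrative sketch: expanding memref<?xf32> into memref<?x4x5xf32> gives
  // otherDimSizesMul = 4 * 5 = 20 for the dynamic output dimension, whose
  // size is then computed below as inDimSize / 20.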
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::UDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
  for (auto &en : llvm::enumerate(reassociation)) {
    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
    for (auto dstIndex : llvm::reverse(en.value())) {
      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
      Value size = dstDesc.size(b, loc, dstIndex);
      currentStrideToExpand =
          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
    }
  }
}

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
  auto srcShape = srcType.getShape();
  for (auto &en : llvm::enumerate(reassociation)) {
    rewriter.setInsertionPoint(op);
    auto dstIndex = en.index();
    ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      dstDesc.setStride(rewriter, loc, dstIndex,
                        srcDesc.stride(rewriter, loc, ref.back()));
    } else {
      // Iterate over the source strides in reverse order. Skip over the
      // dimensions whose size is 1.
      // TODO: we should take the minimum stride in the reassociation group
      // instead of just the first where the dimension is not 1.
      //
      //  +------------------------------------------------------+
      //  | curEntry:                                            |
      //  |   %srcStride = strides[srcIndex]                     |
      //  |   %neOne = cmp sizes[srcIndex],1                     +--+
      //  |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
      //  +-------------------------+----------------------------+  |
      //                            |                               |
      //                            v                               |
      //           +-----------------------------+                  |
      //           | nextEntry:                  |                  |
      //           |   ...                       +---+              |
      //           +--------------+--------------+   |              |
      //                          |                  |              |
      //                          v                  |              |
      //           +-----------------------------+   |              |
      //           | nextEntry:                  |   |              |
      //           |   ...                       |   |   +----------+
      //           +--------------+--------------+   |   |
      //                          |                  |   |
      //                          v                  v   v
      //      +--------------------------------------------------+
      //      | continue(%newStride):                            |
      //      |   %newMemRefDes = setStride(%newStride,dstIndex) |
      //      +--------------------------------------------------+
      OpBuilder::InsertionGuard guard(rewriter);
      Block *initBlock = rewriter.getInsertionBlock();
      Block *continueBlock =
          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
      rewriter.setInsertionPointToStart(continueBlock);
      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));

      Block *curEntryBlock = initBlock;
      Block *nextEntryBlock;
      for (auto srcIndex : llvm::reverse(ref)) {
        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
          continue;
        rewriter.setInsertionPointToEnd(curEntryBlock);
        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
        if (srcIndex == ref.front()) {
          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
          break;
        }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI32IntegerAttr(1));
        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
            loc, LLVM::ICmpPredicate::ne,
            srcDesc.size(rewriter, loc, srcIndex), one);
        {
          OpBuilder::InsertionGuard guard(rewriter);
          nextEntryBlock = rewriter.createBlock(
              initBlock->getParent(), Region::iterator(continueBlock), {});
        }
        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
                                        srcStride, nextEntryBlock, llvm::None);
        curEntryBlock = nextEntryBlock;
      }
    }
  }
}

static void fillInDynamicStridesForMemDescriptor(
    ConversionPatternRewriter &b, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getRank() > dstType.getRank())
    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
                                           srcDesc, dstDesc, reassociation);
  else
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
}

// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static
// sizes and strides.
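//
// Illustrative sketch: collapsing memref<2x4xf32> into memref<8xf32> keeps the
// allocated/aligned pointers and the offset, takes the result stride from the
// innermost source stride, and sets the result size to the product of the
// grouped source sizes (2 * 4 = 8); no data is moved.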
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.getSrc());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else if (srcType.getLayout().isIdentity() &&
               dstType.getLayout().isIdentity()) {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    } else {
      // There could be mixed static/dynamic strides. For simplicity, we
      // recompute all strides if there is at least one dynamic stride.
      fillInDynamicStridesForMemDescriptor(
          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
    }

    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
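///
/// Illustrative sketch (hypothetical IR, not verbatim output): for
///   %v = memref.subview %src[%o0, 4] [8, 16] [1, 1] : memref<64x64xf32> to ...
/// the source descriptor's pointers are reused, the new offset becomes
///   src_offset + %o0 * 64 + 4 * 1,
/// and the sizes (8, 16) and strides (64, 1) are written into the result
/// descriptor; no data is copied.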
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
            extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
            extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0,
                    e = std::min(inferredShapeRank,
                                 subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
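    // In the loop below, `i` indexes the inferred (source-rank) shape while
    // `j` indexes the result shape; dimensions dropped by rank reduction are
    // skipped so that only the surviving dimensions receive a size/stride
    // slot in the target descriptor.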
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may overflow subViewOp.getMixedSizes because of trailing
      // semantics. In this case, the size is guaranteed to be interpreted as
      // Dim and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The transpose op is replaced by the alloca'ed pointer.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
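    // A transpose only permutes descriptor metadata (sizes and strides); the
    // underlying buffer is untouched, so both pointers can be propagated
    // verbatim.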
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset pointer from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms an op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx]
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
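    // For example, a memref<4x8xf32> result has strides [8, 1] and passes the
    // check below, whereas a result whose innermost stride is 2 would be
    // rejected (illustrative; the layouts seen here depend on the input IR).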
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
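/// Kinds without a direct llvm.atomicrmw counterpart (for example the
/// floating-point min/max kinds at the time of writing) return llvm::None, so
/// the pattern below fails to match and the op is left for a
/// compare-and-swap based lowering instead.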
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.getValue().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices(),
                             rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
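// Typical usage (illustrative; the registered command-line flag name is an
// assumption based on the usual convention for this pass):
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createMemRefToLLVMPass());
// or, from the command line:
//   mlir-opt --convert-memref-to-llvm input.mlir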