1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 21 #include "mlir/Dialect/Linalg/Passes.h" 22 #include "mlir/Dialect/SCF/SCF.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Attributes.h" 26 #include "mlir/IR/Builders.h" 27 #include "mlir/IR/BuiltinOps.h" 28 #include "mlir/IR/BuiltinTypes.h" 29 #include "mlir/IR/MLIRContext.h" 30 #include "mlir/IR/Operation.h" 31 #include "mlir/IR/PatternMatch.h" 32 #include "mlir/IR/Types.h" 33 #include "mlir/Support/LogicalResult.h" 34 #include "mlir/Transforms/DialectConversion.h" 35 #include "mlir/Transforms/Passes.h" 36 #include "llvm/ADT/SetVector.h" 37 #include "llvm/IR/DerivedTypes.h" 38 #include "llvm/IR/Module.h" 39 #include "llvm/IR/Type.h" 40 #include "llvm/Support/Allocator.h" 41 #include "llvm/Support/ErrorHandling.h" 42 43 using namespace mlir; 44 using namespace mlir::LLVM; 45 using namespace mlir::linalg; 46 47 template <typename T> 48 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { 49 return LLVMPointerType::get( 50 lowering.convertType(containerType.getElementType())); 51 } 52 53 /// Convert the given range descriptor type to the LLVMIR dialect. 54 /// Range descriptor contains the range bounds and the step as 64-bit integers. 55 /// 56 /// struct { 57 /// int64_t min; 58 /// int64_t max; 59 /// int64_t step; 60 /// }; 61 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { 62 auto *context = t.getContext(); 63 auto int64Ty = converter.convertType(IntegerType::get(context, 64)); 64 return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); 65 } 66 67 namespace { 68 // RangeOp creates a new range descriptor. 69 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> { 70 public: 71 using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern; 72 73 LogicalResult 74 matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands, 75 ConversionPatternRewriter &rewriter) const override { 76 auto rangeDescriptorTy = convertRangeType( 77 rangeOp.getType().cast<RangeType>(), *getTypeConverter()); 78 79 ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter); 80 81 // Fill in an aggregate value of the descriptor. 82 RangeOpAdaptor adaptor(operands); 83 Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy); 84 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(), 85 rewriter.getI64ArrayAttr(0)); 86 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(), 87 rewriter.getI64ArrayAttr(1)); 88 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(), 89 rewriter.getI64ArrayAttr(2)); 90 rewriter.replaceOp(rangeOp, desc); 91 return success(); 92 } 93 }; 94 95 // ReshapeOp creates a new view descriptor of the proper rank. 96 // For now, the only conversion supported is for target MemRef with static sizes 97 // and strides. 98 template <typename ReshapeOp> 99 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> { 100 public: 101 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; 102 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; 103 104 LogicalResult 105 matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands, 106 ConversionPatternRewriter &rewriter) const override { 107 MemRefType dstType = reshapeOp.getResultType(); 108 109 if (!dstType.hasStaticShape()) 110 return failure(); 111 112 int64_t offset; 113 SmallVector<int64_t, 4> strides; 114 auto res = getStridesAndOffset(dstType, strides, offset); 115 if (failed(res) || llvm::any_of(strides, [](int64_t val) { 116 return ShapedType::isDynamicStrideOrOffset(val); 117 })) 118 return failure(); 119 120 ReshapeOpAdaptor adaptor(operands); 121 MemRefDescriptor baseDesc(adaptor.src()); 122 Location loc = reshapeOp->getLoc(); 123 auto desc = 124 MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(), 125 this->typeConverter->convertType(dstType)); 126 desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc)); 127 desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc)); 128 desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc)); 129 for (auto en : llvm::enumerate(dstType.getShape())) 130 desc.setConstantSize(rewriter, loc, en.index(), en.value()); 131 for (auto en : llvm::enumerate(strides)) 132 desc.setConstantStride(rewriter, loc, en.index(), en.value()); 133 rewriter.replaceOp(reshapeOp, {desc}); 134 return success(); 135 } 136 }; 137 138 // YieldOp produces and LLVM::ReturnOp. 139 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> { 140 public: 141 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern; 142 143 LogicalResult 144 matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands, 145 ConversionPatternRewriter &rewriter) const override { 146 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands); 147 return success(); 148 } 149 }; 150 } // namespace 151 152 /// Populate the given list with patterns that convert from Linalg to LLVM. 153 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, 154 RewritePatternSet &patterns) { 155 patterns.add<RangeOpConversion, ReshapeOpConversion<ExpandShapeOp>, 156 ReshapeOpConversion<CollapseShapeOp>, YieldOpConversion>( 157 converter); 158 159 // Populate the type conversions for the linalg types. 160 converter.addConversion( 161 [&](RangeType type) { return convertRangeType(type, converter); }); 162 } 163 164 namespace { 165 struct ConvertLinalgToLLVMPass 166 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> { 167 void runOnOperation() override; 168 }; 169 } // namespace 170 171 void ConvertLinalgToLLVMPass::runOnOperation() { 172 auto module = getOperation(); 173 174 // Convert to the LLVM IR dialect using the converter defined above. 175 RewritePatternSet patterns(&getContext()); 176 LLVMTypeConverter converter(&getContext()); 177 populateLinalgToLLVMConversionPatterns(converter, patterns); 178 179 LLVMConversionTarget target(getContext()); 180 target.addIllegalOp<RangeOp>(); 181 target.addLegalOp<ModuleOp, LLVM::DialectCastOp>(); 182 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 183 signalPassFailure(); 184 } 185 186 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { 187 return std::make_unique<ConvertLinalgToLLVMPass>(); 188 } 189