1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 17 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 18 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 19 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 22 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 23 #include "mlir/Dialect/Linalg/Passes.h" 24 #include "mlir/Dialect/SCF/SCF.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/IR/BuiltinOps.h" 30 #include "mlir/IR/BuiltinTypes.h" 31 #include "mlir/IR/MLIRContext.h" 32 #include "mlir/IR/Operation.h" 33 #include "mlir/IR/PatternMatch.h" 34 #include "mlir/IR/Types.h" 35 #include "mlir/Support/LogicalResult.h" 36 #include "mlir/Transforms/DialectConversion.h" 37 #include "mlir/Transforms/Passes.h" 38 #include "llvm/ADT/SetVector.h" 39 #include "llvm/IR/DerivedTypes.h" 40 #include "llvm/IR/Module.h" 41 #include "llvm/IR/Type.h" 42 #include "llvm/Support/Allocator.h" 43 #include "llvm/Support/ErrorHandling.h" 44 45 using namespace mlir; 46 using namespace mlir::LLVM; 47 using namespace mlir::linalg; 48 49 template <typename T> 50 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { 51 return LLVMPointerType::get( 52 lowering.convertType(containerType.getElementType())); 53 } 54 55 /// Convert the given range descriptor type to the LLVMIR dialect. 56 /// Range descriptor contains the range bounds and the step as 64-bit integers. 57 /// 58 /// struct { 59 /// int64_t min; 60 /// int64_t max; 61 /// int64_t step; 62 /// }; 63 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { 64 auto *context = t.getContext(); 65 auto int64Ty = converter.convertType(IntegerType::get(context, 64)); 66 return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); 67 } 68 69 namespace { 70 // RangeOp creates a new range descriptor. 71 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> { 72 public: 73 using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern; 74 75 LogicalResult 76 matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor, 77 ConversionPatternRewriter &rewriter) const override { 78 auto rangeDescriptorTy = convertRangeType( 79 rangeOp.getType().cast<RangeType>(), *getTypeConverter()); 80 81 ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter); 82 83 // Fill in an aggregate value of the descriptor. 84 Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy); 85 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(), 86 rewriter.getI64ArrayAttr(0)); 87 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(), 88 rewriter.getI64ArrayAttr(1)); 89 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(), 90 rewriter.getI64ArrayAttr(2)); 91 rewriter.replaceOp(rangeOp, desc); 92 return success(); 93 } 94 }; 95 96 97 // YieldOp produces and LLVM::ReturnOp. 98 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> { 99 public: 100 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern; 101 102 LogicalResult 103 matchAndRewrite(linalg::YieldOp op, OpAdaptor adaptor, 104 ConversionPatternRewriter &rewriter) const override { 105 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 106 return success(); 107 } 108 }; 109 } // namespace 110 111 /// Populate the given list with patterns that convert from Linalg to LLVM. 112 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, 113 RewritePatternSet &patterns) { 114 patterns.add<RangeOpConversion, YieldOpConversion>(converter); 115 116 // Populate the type conversions for the linalg types. 117 converter.addConversion( 118 [&](RangeType type) { return convertRangeType(type, converter); }); 119 } 120 121 namespace { 122 struct ConvertLinalgToLLVMPass 123 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> { 124 void runOnOperation() override; 125 }; 126 } // namespace 127 128 void ConvertLinalgToLLVMPass::runOnOperation() { 129 auto module = getOperation(); 130 131 // Convert to the LLVM IR dialect using the converter defined above. 132 RewritePatternSet patterns(&getContext()); 133 LLVMTypeConverter converter(&getContext()); 134 populateLinalgToLLVMConversionPatterns(converter, patterns); 135 populateMemRefToLLVMConversionPatterns(converter, patterns); 136 137 LLVMConversionTarget target(getContext()); 138 target.addIllegalOp<RangeOp>(); 139 target.addLegalOp<ModuleOp>(); 140 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 141 signalPassFailure(); 142 } 143 144 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { 145 return std::make_unique<ConvertLinalgToLLVMPass>(); 146 } 147