1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 17 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 18 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 19 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/Linalg/IR/Linalg.h" 22 #include "mlir/Dialect/Linalg/Passes.h" 23 #include "mlir/Dialect/SCF/SCF.h" 24 #include "mlir/IR/AffineExpr.h" 25 #include "mlir/IR/AffineMap.h" 26 #include "mlir/IR/Attributes.h" 27 #include "mlir/IR/Builders.h" 28 #include "mlir/IR/BuiltinOps.h" 29 #include "mlir/IR/BuiltinTypes.h" 30 #include "mlir/IR/MLIRContext.h" 31 #include "mlir/IR/Operation.h" 32 #include "mlir/IR/PatternMatch.h" 33 #include "mlir/IR/Types.h" 34 #include "mlir/Support/LogicalResult.h" 35 #include "mlir/Transforms/DialectConversion.h" 36 #include "mlir/Transforms/Passes.h" 37 #include "llvm/ADT/SetVector.h" 38 #include "llvm/IR/DerivedTypes.h" 39 #include "llvm/IR/Module.h" 40 #include "llvm/IR/Type.h" 41 #include "llvm/Support/Allocator.h" 42 #include "llvm/Support/ErrorHandling.h" 43 44 using namespace mlir; 45 using namespace mlir::LLVM; 46 using namespace mlir::linalg; 47 48 template <typename T> 49 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { 50 return LLVMPointerType::get( 51 lowering.convertType(containerType.getElementType())); 52 } 53 54 namespace { 55 // YieldOp produces and LLVM::ReturnOp. 56 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> { 57 public: 58 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern; 59 60 LogicalResult 61 matchAndRewrite(linalg::YieldOp op, OpAdaptor adaptor, 62 ConversionPatternRewriter &rewriter) const override { 63 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands()); 64 return success(); 65 } 66 }; 67 } // namespace 68 69 /// Populate the given list with patterns that convert from Linalg to LLVM. 70 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, 71 RewritePatternSet &patterns) { 72 patterns.add<YieldOpConversion>(converter); 73 } 74 75 namespace { 76 struct ConvertLinalgToLLVMPass 77 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> { 78 void runOnOperation() override; 79 }; 80 } // namespace 81 82 void ConvertLinalgToLLVMPass::runOnOperation() { 83 auto module = getOperation(); 84 85 // Convert to the LLVM IR dialect using the converter defined above. 86 RewritePatternSet patterns(&getContext()); 87 LLVMTypeConverter converter(&getContext()); 88 populateLinalgToLLVMConversionPatterns(converter, patterns); 89 populateMemRefToLLVMConversionPatterns(converter, patterns); 90 91 LLVMConversionTarget target(getContext()); 92 target.addLegalOp<ModuleOp>(); 93 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 94 signalPassFailure(); 95 } 96 97 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { 98 return std::make_unique<ConvertLinalgToLLVMPass>(); 99 } 100