1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 17 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 18 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 19 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 22 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 23 #include "mlir/Dialect/Linalg/Passes.h" 24 #include "mlir/Dialect/SCF/SCF.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/IR/BuiltinOps.h" 30 #include "mlir/IR/BuiltinTypes.h" 31 #include "mlir/IR/MLIRContext.h" 32 #include "mlir/IR/Operation.h" 33 #include "mlir/IR/PatternMatch.h" 34 #include "mlir/IR/Types.h" 35 #include "mlir/Support/LogicalResult.h" 36 #include "mlir/Transforms/DialectConversion.h" 37 #include "mlir/Transforms/Passes.h" 38 #include "llvm/ADT/SetVector.h" 39 #include "llvm/IR/DerivedTypes.h" 40 #include "llvm/IR/Module.h" 41 #include "llvm/IR/Type.h" 42 #include "llvm/Support/Allocator.h" 43 #include "llvm/Support/ErrorHandling.h" 44 45 using namespace mlir; 46 using namespace mlir::LLVM; 47 using namespace mlir::linalg; 48 49 template <typename T> 50 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { 51 return LLVMPointerType::get( 52 lowering.convertType(containerType.getElementType())); 53 } 54 55 /// Convert the given range descriptor type to the LLVMIR dialect. 56 /// Range descriptor contains the range bounds and the step as 64-bit integers. 57 /// 58 /// struct { 59 /// int64_t min; 60 /// int64_t max; 61 /// int64_t step; 62 /// }; 63 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { 64 auto *context = t.getContext(); 65 auto int64Ty = converter.convertType(IntegerType::get(context, 64)); 66 return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); 67 } 68 69 namespace { 70 // RangeOp creates a new range descriptor. 71 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> { 72 public: 73 using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern; 74 75 LogicalResult 76 matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands, 77 ConversionPatternRewriter &rewriter) const override { 78 auto rangeDescriptorTy = convertRangeType( 79 rangeOp.getType().cast<RangeType>(), *getTypeConverter()); 80 81 ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter); 82 83 // Fill in an aggregate value of the descriptor. 84 RangeOpAdaptor adaptor(operands); 85 Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy); 86 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(), 87 rewriter.getI64ArrayAttr(0)); 88 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(), 89 rewriter.getI64ArrayAttr(1)); 90 desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(), 91 rewriter.getI64ArrayAttr(2)); 92 rewriter.replaceOp(rangeOp, desc); 93 return success(); 94 } 95 }; 96 97 98 // YieldOp produces and LLVM::ReturnOp. 99 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> { 100 public: 101 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern; 102 103 LogicalResult 104 matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands, 105 ConversionPatternRewriter &rewriter) const override { 106 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands); 107 return success(); 108 } 109 }; 110 } // namespace 111 112 /// Populate the given list with patterns that convert from Linalg to LLVM. 113 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, 114 RewritePatternSet &patterns) { 115 patterns.add<RangeOpConversion, YieldOpConversion>(converter); 116 117 // Populate the type conversions for the linalg types. 118 converter.addConversion( 119 [&](RangeType type) { return convertRangeType(type, converter); }); 120 } 121 122 namespace { 123 struct ConvertLinalgToLLVMPass 124 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> { 125 void runOnOperation() override; 126 }; 127 } // namespace 128 129 void ConvertLinalgToLLVMPass::runOnOperation() { 130 auto module = getOperation(); 131 132 // Convert to the LLVM IR dialect using the converter defined above. 133 RewritePatternSet patterns(&getContext()); 134 LLVMTypeConverter converter(&getContext()); 135 populateLinalgToLLVMConversionPatterns(converter, patterns); 136 populateMemRefToLLVMConversionPatterns(converter, patterns); 137 138 LLVMConversionTarget target(getContext()); 139 target.addIllegalOp<RangeOp>(); 140 target.addLegalOp<ModuleOp>(); 141 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 142 signalPassFailure(); 143 } 144 145 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { 146 return std::make_unique<ConvertLinalgToLLVMPass>(); 147 } 148