1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 21 #include "mlir/Dialect/Linalg/Passes.h" 22 #include "mlir/Dialect/SCF/SCF.h" 23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 24 #include "mlir/IR/AffineExpr.h" 25 #include "mlir/IR/AffineMap.h" 26 #include "mlir/IR/Attributes.h" 27 #include "mlir/IR/Builders.h" 28 #include "mlir/IR/BuiltinOps.h" 29 #include "mlir/IR/BuiltinTypes.h" 30 #include "mlir/IR/MLIRContext.h" 31 #include "mlir/IR/Operation.h" 32 #include "mlir/IR/PatternMatch.h" 33 #include "mlir/IR/Types.h" 34 #include "mlir/Support/LogicalResult.h" 35 #include "mlir/Transforms/DialectConversion.h" 36 #include "mlir/Transforms/Passes.h" 37 #include "llvm/ADT/SetVector.h" 38 #include "llvm/IR/DerivedTypes.h" 39 #include "llvm/IR/Module.h" 40 #include "llvm/IR/Type.h" 41 #include "llvm/Support/Allocator.h" 42 #include "llvm/Support/ErrorHandling.h" 43 44 using namespace mlir; 45 using namespace mlir::edsc; 46 using namespace mlir::edsc::intrinsics; 47 using namespace mlir::LLVM; 48 using namespace mlir::linalg; 49 50 using llvm_add = ValueBuilder<LLVM::AddOp>; 51 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>; 52 using llvm_constant = ValueBuilder<LLVM::ConstantOp>; 53 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>; 54 using llvm_gep = ValueBuilder<LLVM::GEPOp>; 55 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>; 56 using llvm_call = OperationBuilder<LLVM::CallOp>; 57 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>; 58 using llvm_load = ValueBuilder<LLVM::LoadOp>; 59 using llvm_store = OperationBuilder<LLVM::StoreOp>; 60 using llvm_select = ValueBuilder<LLVM::SelectOp>; 61 using llvm_mul = ValueBuilder<LLVM::MulOp>; 62 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>; 63 using llvm_sub = ValueBuilder<LLVM::SubOp>; 64 using llvm_undef = ValueBuilder<LLVM::UndefOp>; 65 using llvm_urem = ValueBuilder<LLVM::URemOp>; 66 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>; 67 using llvm_return = OperationBuilder<LLVM::ReturnOp>; 68 69 template <typename T> 70 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { 71 return LLVMPointerType::get( 72 lowering.convertType(containerType.getElementType())); 73 } 74 75 /// Convert the given range descriptor type to the LLVMIR dialect. 76 /// Range descriptor contains the range bounds and the step as 64-bit integers. 77 /// 78 /// struct { 79 /// int64_t min; 80 /// int64_t max; 81 /// int64_t step; 82 /// }; 83 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { 84 auto *context = t.getContext(); 85 auto int64Ty = converter.convertType(IntegerType::get(context, 64)); 86 return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); 87 } 88 89 namespace { 90 /// EDSC-compatible wrapper for MemRefDescriptor. 91 class BaseViewConversionHelper { 92 public: 93 BaseViewConversionHelper(Type type) 94 : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} 95 96 BaseViewConversionHelper(Value v) : d(v) {} 97 98 /// Wrappers around MemRefDescriptor that use EDSC builder and location. 99 Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } 100 void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); } 101 Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); } 102 void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); } 103 Value offset() { return d.offset(rewriter(), loc()); } 104 void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } 105 Value size(unsigned i) { return d.size(rewriter(), loc(), i); } 106 void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } 107 void setConstantSize(unsigned i, int64_t v) { 108 d.setConstantSize(rewriter(), loc(), i, v); 109 } 110 Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } 111 void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } 112 void setConstantStride(unsigned i, int64_t v) { 113 d.setConstantStride(rewriter(), loc(), i, v); 114 } 115 116 operator Value() { return d; } 117 118 private: 119 OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); } 120 Location loc() { return ScopedContext::getLocation(); } 121 122 MemRefDescriptor d; 123 }; 124 125 // RangeOp creates a new range descriptor. 126 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> { 127 public: 128 using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern; 129 130 LogicalResult 131 matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands, 132 ConversionPatternRewriter &rewriter) const override { 133 auto rangeDescriptorTy = convertRangeType( 134 rangeOp.getType().cast<RangeType>(), *getTypeConverter()); 135 136 edsc::ScopedContext context(rewriter, rangeOp->getLoc()); 137 138 // Fill in an aggregate value of the descriptor. 139 RangeOpAdaptor adaptor(operands); 140 Value desc = llvm_undef(rangeDescriptorTy); 141 desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); 142 desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); 143 desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); 144 rewriter.replaceOp(rangeOp, desc); 145 return success(); 146 } 147 }; 148 149 // ReshapeOp creates a new view descriptor of the proper rank. 150 // For now, the only conversion supported is for target MemRef with static sizes 151 // and strides. 152 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> { 153 public: 154 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; 155 156 LogicalResult 157 matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands, 158 ConversionPatternRewriter &rewriter) const override { 159 MemRefType dstType = reshapeOp.getResultType(); 160 161 if (!dstType.hasStaticShape()) 162 return failure(); 163 164 int64_t offset; 165 SmallVector<int64_t, 4> strides; 166 auto res = getStridesAndOffset(dstType, strides, offset); 167 if (failed(res) || llvm::any_of(strides, [](int64_t val) { 168 return ShapedType::isDynamicStrideOrOffset(val); 169 })) 170 return failure(); 171 172 edsc::ScopedContext context(rewriter, reshapeOp->getLoc()); 173 ReshapeOpAdaptor adaptor(operands); 174 BaseViewConversionHelper baseDesc(adaptor.src()); 175 BaseViewConversionHelper desc(typeConverter->convertType(dstType)); 176 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 177 desc.setAlignedPtr(baseDesc.alignedPtr()); 178 desc.setOffset(baseDesc.offset()); 179 for (auto en : llvm::enumerate(dstType.getShape())) 180 desc.setConstantSize(en.index(), en.value()); 181 for (auto en : llvm::enumerate(strides)) 182 desc.setConstantStride(en.index(), en.value()); 183 rewriter.replaceOp(reshapeOp, {desc}); 184 return success(); 185 } 186 }; 187 188 // YieldOp produces and LLVM::ReturnOp. 189 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> { 190 public: 191 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern; 192 193 LogicalResult 194 matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands, 195 ConversionPatternRewriter &rewriter) const override { 196 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands); 197 return success(); 198 } 199 }; 200 } // namespace 201 202 /// Populate the given list with patterns that convert from Linalg to LLVM. 203 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, 204 RewritePatternSet &patterns) { 205 patterns.add<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>( 206 converter); 207 208 // Populate the type conversions for the linalg types. 209 converter.addConversion( 210 [&](RangeType type) { return convertRangeType(type, converter); }); 211 } 212 213 namespace { 214 struct ConvertLinalgToLLVMPass 215 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> { 216 void runOnOperation() override; 217 }; 218 } // namespace 219 220 void ConvertLinalgToLLVMPass::runOnOperation() { 221 auto module = getOperation(); 222 223 // Convert to the LLVM IR dialect using the converter defined above. 224 RewritePatternSet patterns(&getContext()); 225 LLVMTypeConverter converter(&getContext()); 226 populateLinalgToLLVMConversionPatterns(converter, patterns); 227 228 LLVMConversionTarget target(getContext()); 229 target.addIllegalOp<RangeOp>(); 230 target.addLegalOp<ModuleOp, LLVM::DialectCastOp>(); 231 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 232 signalPassFailure(); 233 } 234 235 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { 236 return std::make_unique<ConvertLinalgToLLVMPass>(); 237 } 238