1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 21 #include "mlir/Dialect/Linalg/Passes.h" 22 #include "mlir/Dialect/SCF/SCF.h" 23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 24 #include "mlir/IR/AffineExpr.h" 25 #include "mlir/IR/AffineMap.h" 26 #include "mlir/IR/Attributes.h" 27 #include "mlir/IR/Builders.h" 28 #include "mlir/IR/BuiltinOps.h" 29 #include "mlir/IR/BuiltinTypes.h" 30 #include "mlir/IR/MLIRContext.h" 31 #include "mlir/IR/Operation.h" 32 #include "mlir/IR/PatternMatch.h" 33 #include "mlir/IR/Types.h" 34 #include "mlir/Support/LogicalResult.h" 35 #include "mlir/Transforms/DialectConversion.h" 36 #include "mlir/Transforms/Passes.h" 37 #include "llvm/ADT/SetVector.h" 38 #include "llvm/IR/DerivedTypes.h" 39 #include "llvm/IR/Module.h" 40 #include "llvm/IR/Type.h" 41 #include "llvm/Support/Allocator.h" 42 #include "llvm/Support/ErrorHandling.h" 43 44 using namespace mlir; 45 using namespace mlir::edsc; 46 using namespace mlir::edsc::intrinsics; 47 using namespace mlir::LLVM; 48 using namespace mlir::linalg; 49 50 using llvm_add = ValueBuilder<LLVM::AddOp>; 51 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>; 52 using llvm_constant = ValueBuilder<LLVM::ConstantOp>; 53 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>; 54 using llvm_gep = ValueBuilder<LLVM::GEPOp>; 55 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>; 56 using llvm_call = OperationBuilder<LLVM::CallOp>; 57 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>; 58 using llvm_load = ValueBuilder<LLVM::LoadOp>; 59 using llvm_store = OperationBuilder<LLVM::StoreOp>; 60 using llvm_select = ValueBuilder<LLVM::SelectOp>; 61 using llvm_mul = ValueBuilder<LLVM::MulOp>; 62 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>; 63 using llvm_sub = ValueBuilder<LLVM::SubOp>; 64 using llvm_undef = ValueBuilder<LLVM::UndefOp>; 65 using llvm_urem = ValueBuilder<LLVM::URemOp>; 66 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>; 67 using llvm_return = OperationBuilder<LLVM::ReturnOp>; 68 69 template <typename T> 70 static LLVMType getPtrToElementType(T containerType, 71 LLVMTypeConverter &lowering) { 72 return lowering.convertType(containerType.getElementType()) 73 .template cast<LLVMType>() 74 .getPointerTo(); 75 } 76 77 /// Convert the given range descriptor type to the LLVMIR dialect. 78 /// Range descriptor contains the range bounds and the step as 64-bit integers. 79 /// 80 /// struct { 81 /// int64_t min; 82 /// int64_t max; 83 /// int64_t step; 84 /// }; 85 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { 86 auto *context = t.getContext(); 87 auto int64Ty = converter.convertType(IntegerType::get(context, 64)) 88 .cast<LLVM::LLVMType>(); 89 return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); 90 } 91 92 namespace { 93 /// EDSC-compatible wrapper for MemRefDescriptor. 94 class BaseViewConversionHelper { 95 public: 96 BaseViewConversionHelper(Type type) 97 : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} 98 99 BaseViewConversionHelper(Value v) : d(v) {} 100 101 /// Wrappers around MemRefDescriptor that use EDSC builder and location. 102 Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } 103 void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); } 104 Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); } 105 void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); } 106 Value offset() { return d.offset(rewriter(), loc()); } 107 void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } 108 Value size(unsigned i) { return d.size(rewriter(), loc(), i); } 109 void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } 110 void setConstantSize(unsigned i, int64_t v) { 111 d.setConstantSize(rewriter(), loc(), i, v); 112 } 113 Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } 114 void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } 115 void setConstantStride(unsigned i, int64_t v) { 116 d.setConstantStride(rewriter(), loc(), i, v); 117 } 118 119 operator Value() { return d; } 120 121 private: 122 OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); } 123 Location loc() { return ScopedContext::getLocation(); } 124 125 MemRefDescriptor d; 126 }; 127 128 // RangeOp creates a new range descriptor. 129 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> { 130 public: 131 using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern; 132 133 LogicalResult 134 matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands, 135 ConversionPatternRewriter &rewriter) const override { 136 auto rangeDescriptorTy = convertRangeType( 137 rangeOp.getType().cast<RangeType>(), *getTypeConverter()); 138 139 edsc::ScopedContext context(rewriter, rangeOp->getLoc()); 140 141 // Fill in an aggregate value of the descriptor. 142 RangeOpAdaptor adaptor(operands); 143 Value desc = llvm_undef(rangeDescriptorTy); 144 desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); 145 desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); 146 desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); 147 rewriter.replaceOp(rangeOp, desc); 148 return success(); 149 } 150 }; 151 152 // ReshapeOp creates a new view descriptor of the proper rank. 153 // For now, the only conversion supported is for target MemRef with static sizes 154 // and strides. 155 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> { 156 public: 157 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; 158 159 LogicalResult 160 matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands, 161 ConversionPatternRewriter &rewriter) const override { 162 MemRefType dstType = reshapeOp.getResultType(); 163 164 if (!dstType.hasStaticShape()) 165 return failure(); 166 167 int64_t offset; 168 SmallVector<int64_t, 4> strides; 169 auto res = getStridesAndOffset(dstType, strides, offset); 170 if (failed(res) || llvm::any_of(strides, [](int64_t val) { 171 return ShapedType::isDynamicStrideOrOffset(val); 172 })) 173 return failure(); 174 175 edsc::ScopedContext context(rewriter, reshapeOp->getLoc()); 176 ReshapeOpAdaptor adaptor(operands); 177 BaseViewConversionHelper baseDesc(adaptor.src()); 178 BaseViewConversionHelper desc(typeConverter->convertType(dstType)); 179 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 180 desc.setAlignedPtr(baseDesc.alignedPtr()); 181 desc.setOffset(baseDesc.offset()); 182 for (auto en : llvm::enumerate(dstType.getShape())) 183 desc.setConstantSize(en.index(), en.value()); 184 for (auto en : llvm::enumerate(strides)) 185 desc.setConstantStride(en.index(), en.value()); 186 rewriter.replaceOp(reshapeOp, {desc}); 187 return success(); 188 } 189 }; 190 191 /// Conversion pattern that transforms a linalg.slice op into: 192 /// 1. An "undef" value for the ViewDescriptor. 193 /// 2. Updates to the ViewDescriptor to introduce the data ptr, offset, size 194 /// and stride corresponding to the region of memory within the bounds of 195 /// the parent view. 196 /// The linalg.slice op is replaced by the alloca'ed pointer. 197 class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> { 198 public: 199 using ConvertOpToLLVMPattern<SliceOp>::ConvertOpToLLVMPattern; 200 201 LogicalResult 202 matchAndRewrite(SliceOp sliceOp, ArrayRef<Value> operands, 203 ConversionPatternRewriter &rewriter) const override { 204 edsc::ScopedContext context(rewriter, sliceOp->getLoc()); 205 SliceOpAdaptor adaptor(operands); 206 BaseViewConversionHelper baseDesc(adaptor.view()); 207 208 auto memRefType = sliceOp.getBaseViewType(); 209 auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64)) 210 .cast<LLVM::LLVMType>(); 211 212 BaseViewConversionHelper desc( 213 typeConverter->convertType(sliceOp.getShapedType())); 214 215 // TODO: extract sizes and emit asserts. 216 SmallVector<Value, 4> strides(memRefType.getRank()); 217 for (int i = 0, e = memRefType.getRank(); i < e; ++i) 218 strides[i] = baseDesc.stride(i); 219 220 auto pos = [&rewriter](ArrayRef<int64_t> values) { 221 return rewriter.getI64ArrayAttr(values); 222 }; 223 224 // Compute base offset. 225 Value baseOffset = baseDesc.offset(); 226 for (int i = 0, e = memRefType.getRank(); i < e; ++i) { 227 Value indexing = adaptor.indexings()[i]; 228 Value min = indexing; 229 if (sliceOp.indexing(i).getType().isa<RangeType>()) 230 min = llvm_extractvalue(int64Ty, indexing, pos(0)); 231 baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i])); 232 } 233 234 // Insert the base and aligned pointers. 235 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 236 desc.setAlignedPtr(baseDesc.alignedPtr()); 237 238 // Insert base offset. 239 desc.setOffset(baseOffset); 240 241 // Corner case, no sizes or strides: early return the descriptor. 242 if (sliceOp.getShapedType().getRank() == 0) 243 return rewriter.replaceOp(sliceOp, {desc}), success(); 244 245 Value zero = llvm_constant( 246 int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); 247 // Compute and insert view sizes (max - min along the range) and strides. 248 // Skip the non-range operands as they will be projected away from the view. 249 int numNewDims = 0; 250 for (auto en : llvm::enumerate(sliceOp.indexings())) { 251 Value indexing = en.value(); 252 if (indexing.getType().isa<RangeType>()) { 253 int rank = en.index(); 254 Value rangeDescriptor = adaptor.indexings()[rank]; 255 Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0)); 256 Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1)); 257 Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2)); 258 Value baseSize = baseDesc.size(rank); 259 260 // Bound upper by base view upper bound. 261 max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, 262 baseSize); 263 Value size = llvm_sub(max, min); 264 // Bound lower by zero. 265 size = 266 llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); 267 Value stride = llvm_mul(strides[rank], step); 268 desc.setSize(numNewDims, size); 269 desc.setStride(numNewDims, stride); 270 ++numNewDims; 271 } 272 } 273 274 rewriter.replaceOp(sliceOp, {desc}); 275 return success(); 276 } 277 }; 278 279 // YieldOp produces and LLVM::ReturnOp. 280 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> { 281 public: 282 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern; 283 284 LogicalResult 285 matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands, 286 ConversionPatternRewriter &rewriter) const override { 287 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands); 288 return success(); 289 } 290 }; 291 } // namespace 292 293 /// Populate the given list with patterns that convert from Linalg to LLVM. 294 void mlir::populateLinalgToLLVMConversionPatterns( 295 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 296 patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion, 297 YieldOpConversion>(converter); 298 299 // Populate the type conversions for the linalg types. 300 converter.addConversion( 301 [&](RangeType type) { return convertRangeType(type, converter); }); 302 } 303 304 namespace { 305 struct ConvertLinalgToLLVMPass 306 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> { 307 void runOnOperation() override; 308 }; 309 } // namespace 310 311 void ConvertLinalgToLLVMPass::runOnOperation() { 312 auto module = getOperation(); 313 314 // Convert to the LLVM IR dialect using the converter defined above. 315 OwningRewritePatternList patterns; 316 LLVMTypeConverter converter(&getContext()); 317 populateAffineToStdConversionPatterns(patterns, &getContext()); 318 populateLoopToStdConversionPatterns(patterns, &getContext()); 319 populateStdToLLVMConversionPatterns(converter, patterns); 320 populateVectorToSCFConversionPatterns(patterns, &getContext()); 321 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 322 populateVectorToLLVMConversionPatterns(converter, patterns); 323 populateLinalgToLLVMConversionPatterns(converter, patterns); 324 325 LLVMConversionTarget target(getContext()); 326 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 327 if (failed(applyFullConversion(module, target, std::move(patterns)))) 328 signalPassFailure(); 329 } 330 331 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { 332 return std::make_unique<ConvertLinalgToLLVMPass>(); 333 } 334