//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;

// Short EDSC aliases for the LLVM dialect op builders used in the patterns
// below. ValueBuilder creates an op and yields its result Value;
// OperationBuilder is used for ops with no SSA result (or where the Operation
// itself is wanted).
using llvm_add = ValueBuilder<LLVM::AddOp>;
using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
using llvm_gep = ValueBuilder<LLVM::GEPOp>;
using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using llvm_mul = ValueBuilder<LLVM::MulOp>;
using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
using llvm_sub = ValueBuilder<LLVM::SubOp>;
using llvm_undef = ValueBuilder<LLVM::UndefOp>;
using llvm_urem = ValueBuilder<LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;

/// Returns an LLVM dialect pointer type to the converted element type of
/// `containerType` (any shaped type exposing getElementType()).
template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVMType>()
      .getPointerTo();
}

/// Convert the given range descriptor type to the LLVMIR dialect.
/// Range descriptor contains the range bounds and the step as 64-bit integers.
///
///   struct {
///     int64_t min;
///     int64_t max;
///     int64_t step;
///   };
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
  auto *context = t.getContext();
  auto int64Ty = converter.convertType(IntegerType::get(64, context))
                     .cast<LLVM::LLVMType>();
  return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
}

namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
  /// Creates a helper over a fresh "undef" descriptor of the given (already
  /// converted) type. NOTE(review): relies on an active edsc::ScopedContext
  /// for builder and location (see rewriter()/loc() below).
  BaseViewConversionHelper(Type type)
      : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}

  /// Wraps an existing descriptor value.
  BaseViewConversionHelper(Value v) : d(v) {}

  /// Wrappers around MemRefDescriptor that use EDSC builder and location.
  Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
  void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
  void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
  Value offset() { return d.offset(rewriter(), loc()); }
  void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
  Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
  void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
  void setConstantSize(unsigned i, int64_t v) {
    d.setConstantSize(rewriter(), loc(), i, v);
  }
  Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
  void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
  void setConstantStride(unsigned i, int64_t v) {
    d.setConstantStride(rewriter(), loc(), i, v);
  }

  /// Implicit conversion back to the underlying descriptor value.
  operator Value() { return d; }

private:
  // Builder/location are pulled from the enclosing EDSC scoped context rather
  // than stored, so the helper stays valid across builder insertions.
  OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
  Location loc() { return ScopedContext::getLocation(); }

  MemRefDescriptor d;
};

// RangeOp creates a new range descriptor.
129 class RangeOpConversion : public ConvertToLLVMPattern { 130 public: 131 explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) 132 : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {} 133 134 LogicalResult 135 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 136 ConversionPatternRewriter &rewriter) const override { 137 auto rangeOp = cast<RangeOp>(op); 138 auto rangeDescriptorTy = 139 convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter); 140 141 edsc::ScopedContext context(rewriter, op->getLoc()); 142 143 // Fill in an aggregate value of the descriptor. 144 RangeOpAdaptor adaptor(operands); 145 Value desc = llvm_undef(rangeDescriptorTy); 146 desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); 147 desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); 148 desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); 149 rewriter.replaceOp(op, desc); 150 return success(); 151 } 152 }; 153 154 // ReshapeOp creates a new view descriptor of the proper rank. 155 // For now, the only conversion supported is for target MemRef with static sizes 156 // and strides. 
157 class ReshapeOpConversion : public ConvertToLLVMPattern { 158 public: 159 explicit ReshapeOpConversion(MLIRContext *context, 160 LLVMTypeConverter &lowering_) 161 : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context, 162 lowering_) {} 163 164 LogicalResult 165 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 166 ConversionPatternRewriter &rewriter) const override { 167 auto reshapeOp = cast<ReshapeOp>(op); 168 MemRefType dstType = reshapeOp.getResultType(); 169 170 if (!dstType.hasStaticShape()) 171 return failure(); 172 173 int64_t offset; 174 SmallVector<int64_t, 4> strides; 175 auto res = getStridesAndOffset(dstType, strides, offset); 176 if (failed(res) || llvm::any_of(strides, [](int64_t val) { 177 return ShapedType::isDynamicStrideOrOffset(val); 178 })) 179 return failure(); 180 181 edsc::ScopedContext context(rewriter, op->getLoc()); 182 ReshapeOpAdaptor adaptor(operands); 183 BaseViewConversionHelper baseDesc(adaptor.src()); 184 BaseViewConversionHelper desc(typeConverter.convertType(dstType)); 185 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 186 desc.setAlignedPtr(baseDesc.alignedPtr()); 187 desc.setOffset(baseDesc.offset()); 188 for (auto en : llvm::enumerate(dstType.getShape())) 189 desc.setConstantSize(en.index(), en.value()); 190 for (auto en : llvm::enumerate(strides)) 191 desc.setConstantStride(en.index(), en.value()); 192 rewriter.replaceOp(op, {desc}); 193 return success(); 194 } 195 }; 196 197 /// Conversion pattern that transforms a linalg.slice op into: 198 /// 1. An "undef" value for the ViewDescriptor. 199 /// 2. Updates to the ViewDescriptor to introduce the data ptr, offset, size 200 /// and stride corresponding to the region of memory within the bounds of 201 /// the parent view. 202 /// The linalg.slice op is replaced by the alloca'ed pointer. 
class SliceOpConversion : public ConvertToLLVMPattern {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpAdaptor adaptor(operands);
    // Descriptor of the (already converted) base view being sliced.
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Descriptor for the resulting (possibly rank-reduced) view.
    BaseViewConversionHelper desc(
        typeConverter.convertType(sliceOp.getShapedType()));

    // TODO: extract sizes and emit asserts.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Helper building the I64 array attribute used as the position operand of
    // llvm.extractvalue below.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: accumulate min(indexing_i) * stride_i over all
    // dimensions of the base view.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      // A range indexing contributes its `min` field (position 0 of the range
      // descriptor); a scalar index is used directly.
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers (the slice aliases its base view).
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), success();

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    // `numNewDims` tracks the dimension index in the rank-reduced result.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        // Unpack {min, max, step} from the range descriptor.
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = llvm_sub(max, min);
        // Bound lower by zero (clamp empty/negative extents).
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // The slice stride composes the base stride with the range step.
        Value stride = llvm_mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

// YieldOp produces an LLVM::ReturnOp.
/// Lowers linalg.yield (region terminator) to llvm.return with the converted
/// operands.
class YieldOpConversion : public ConvertToLLVMPattern {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
                             lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return success();
  }
};
} // namespace

/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
                  YieldOpConversion>(ctx, converter);

  // Populate the type conversions for the linalg types.
  converter.addConversion(
      [&](RangeType type) { return convertRangeType(type, converter); });
}

namespace {
/// Module pass lowering Linalg (and the dialects reachable from it below) to
/// the LLVM dialect in a single full conversion.
struct ConvertLinalgToLLVMPass
    : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
  void runOnOperation() override;
};
} // namespace

void ConvertLinalgToLLVMPass::runOnOperation() {
  auto module = getOperation();

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LLVMTypeConverter converter(&getContext());
  // Pull in affine/scf/std/vector conversions as well, since the full
  // conversion below must be able to legalize everything it encounters.
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateVectorToSCFConversionPatterns(patterns, &getContext());
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  LLVMConversionTarget target(getContext());
  // The module and its terminator stay as-is; everything else must legalize.
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  if (failed(applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

/// Factory for the standalone Linalg-to-LLVM conversion pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}