1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 11 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 12 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" 13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 15 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/MLIRContext.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 #include "mlir/IR/Types.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Pass/PassManager.h" 33 #include "mlir/Support/LogicalResult.h" 34 #include "mlir/Transforms/DialectConversion.h" 35 #include "mlir/Transforms/Passes.h" 36 #include "llvm/ADT/SetVector.h" 37 #include "llvm/IR/DerivedTypes.h" 38 #include "llvm/IR/Module.h" 39 #include "llvm/IR/Type.h" 40 #include "llvm/Support/Allocator.h" 41 #include "llvm/Support/ErrorHandling.h" 42 43 using namespace mlir; 44 using namespace mlir::edsc; 45 using namespace mlir::edsc::intrinsics; 46 using namespace mlir::LLVM; 47 using namespace mlir::linalg; 48 49 using llvm_add = 
ValueBuilder<LLVM::AddOp>; 50 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>; 51 using llvm_constant = ValueBuilder<LLVM::ConstantOp>; 52 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>; 53 using llvm_gep = ValueBuilder<LLVM::GEPOp>; 54 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>; 55 using llvm_call = OperationBuilder<LLVM::CallOp>; 56 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>; 57 using llvm_load = ValueBuilder<LLVM::LoadOp>; 58 using llvm_store = OperationBuilder<LLVM::StoreOp>; 59 using llvm_select = ValueBuilder<LLVM::SelectOp>; 60 using llvm_mul = ValueBuilder<LLVM::MulOp>; 61 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>; 62 using llvm_sub = ValueBuilder<LLVM::SubOp>; 63 using llvm_undef = ValueBuilder<LLVM::UndefOp>; 64 using llvm_urem = ValueBuilder<LLVM::URemOp>; 65 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>; 66 using llvm_return = OperationBuilder<LLVM::ReturnOp>; 67 68 template <typename T> 69 static LLVMType getPtrToElementType(T containerType, 70 LLVMTypeConverter &lowering) { 71 return lowering.convertType(containerType.getElementType()) 72 .template cast<LLVMType>() 73 .getPointerTo(); 74 } 75 76 /// Convert the given range descriptor type to the LLVMIR dialect. 77 /// Range descriptor contains the range bounds and the step as 64-bit integers. 78 /// 79 /// struct { 80 /// int64_t min; 81 /// int64_t max; 82 /// int64_t step; 83 /// }; 84 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { 85 auto *context = t.getContext(); 86 auto int64Ty = converter.convertType(IntegerType::get(64, context)) 87 .cast<LLVM::LLVMType>(); 88 return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); 89 } 90 91 namespace { 92 /// EDSC-compatible wrapper for MemRefDescriptor. 
93 class BaseViewConversionHelper { 94 public: 95 BaseViewConversionHelper(Type type) 96 : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} 97 98 BaseViewConversionHelper(Value v) : d(v) {} 99 100 /// Wrappers around MemRefDescriptor that use EDSC builder and location. 101 Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } 102 void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); } 103 Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); } 104 void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); } 105 Value offset() { return d.offset(rewriter(), loc()); } 106 void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } 107 Value size(unsigned i) { return d.size(rewriter(), loc(), i); } 108 void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } 109 void setConstantSize(unsigned i, int64_t v) { 110 d.setConstantSize(rewriter(), loc(), i, v); 111 } 112 Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } 113 void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } 114 void setConstantStride(unsigned i, int64_t v) { 115 d.setConstantStride(rewriter(), loc(), i, v); 116 } 117 118 operator Value() { return d; } 119 120 private: 121 OpBuilder &rewriter() { return ScopedContext::getBuilder(); } 122 Location loc() { return ScopedContext::getLocation(); } 123 124 MemRefDescriptor d; 125 }; 126 127 // RangeOp creates a new range descriptor. 
128 class RangeOpConversion : public ConvertToLLVMPattern { 129 public: 130 explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) 131 : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {} 132 133 LogicalResult 134 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 135 ConversionPatternRewriter &rewriter) const override { 136 auto rangeOp = cast<RangeOp>(op); 137 auto rangeDescriptorTy = 138 convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter); 139 140 edsc::ScopedContext context(rewriter, op->getLoc()); 141 142 // Fill in an aggregate value of the descriptor. 143 RangeOpOperandAdaptor adaptor(operands); 144 Value desc = llvm_undef(rangeDescriptorTy); 145 desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); 146 desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); 147 desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); 148 rewriter.replaceOp(op, desc); 149 return success(); 150 } 151 }; 152 153 // ReshapeOp creates a new view descriptor of the proper rank. 154 // For now, the only conversion supported is for target MemRef with static sizes 155 // and strides. 
156 class ReshapeOpConversion : public ConvertToLLVMPattern { 157 public: 158 explicit ReshapeOpConversion(MLIRContext *context, 159 LLVMTypeConverter &lowering_) 160 : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context, 161 lowering_) {} 162 163 LogicalResult 164 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 165 ConversionPatternRewriter &rewriter) const override { 166 auto reshapeOp = cast<ReshapeOp>(op); 167 MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>(); 168 169 if (!dstType.hasStaticShape()) 170 return failure(); 171 172 int64_t offset; 173 SmallVector<int64_t, 4> strides; 174 auto res = getStridesAndOffset(dstType, strides, offset); 175 if (failed(res) || llvm::any_of(strides, [](int64_t val) { 176 return ShapedType::isDynamicStrideOrOffset(val); 177 })) 178 return failure(); 179 180 edsc::ScopedContext context(rewriter, op->getLoc()); 181 ReshapeOpOperandAdaptor adaptor(operands); 182 BaseViewConversionHelper baseDesc(adaptor.view()); 183 BaseViewConversionHelper desc(typeConverter.convertType(dstType)); 184 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 185 desc.setAlignedPtr(baseDesc.alignedPtr()); 186 desc.setOffset(baseDesc.offset()); 187 for (auto en : llvm::enumerate(dstType.getShape())) 188 desc.setConstantSize(en.index(), en.value()); 189 for (auto en : llvm::enumerate(strides)) 190 desc.setConstantStride(en.index(), en.value()); 191 rewriter.replaceOp(op, {desc}); 192 return success(); 193 } 194 }; 195 196 /// Conversion pattern that transforms a linalg.slice op into: 197 /// 1. An "undef" value for the ViewDescriptor. 198 /// 2. Updates to the ViewDescriptor to introduce the data ptr, offset, size 199 /// and stride corresponding to the region of memory within the bounds of 200 /// the parent view. 201 /// The linalg.slice op is replaced by the alloca'ed pointer. 
class SliceOpConversion : public ConvertToLLVMPattern {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}

  /// Builds a new descriptor for the sliced view from the base view's
  /// descriptor and the converted indexings (range descriptors or scalar
  /// indices).
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpOperandAdaptor adaptor(operands);
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Fresh undef descriptor for the result view.
    BaseViewConversionHelper desc(
        typeConverter.convertType(sliceOp.getShapedType()));

    // TODO(ntv): extract sizes and emit asserts.
    // Read every stride of the base view once, up front.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Shorthand for building the position attribute of extractvalue.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: baseOffset + sum_i(min_i * stride_i), where min_i
    // is field 0 of a range descriptor, or the scalar index itself when the
    // i-th indexing is not a range.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), success();

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    // `numNewDims` counts result dimensions, which differ from base-view
    // dimensions whenever a scalar index drops a dimension.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        // Despite its name, `rank` is the dimension index in the base view.
        int rank = en.index();
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        // NOTE(review): `size` is `max - min` and is not divided by `step`,
        // while the stride below is multiplied by `step`. For step > 1 this
        // may describe more elements than the range contains. Confirm the
        // intended linalg.slice semantics for non-unit steps.
        Value size = llvm_sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        Value stride = llvm_mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that transforms a linalg.transpose op into:
///   1. An "undef" value for a new ViewDescriptor. This step is skipped when
///      the permutation is the identity; the source descriptor is then
///      forwarded unchanged.
///   2. Updates to the new ViewDescriptor that copy the data pointers and the
///      offset from the source view. Sizes and strides are permutations of
///      the source values.
/// The linalg.transpose op is replaced by the resulting descriptor value.
293 class TransposeOpConversion : public ConvertToLLVMPattern { 294 public: 295 explicit TransposeOpConversion(MLIRContext *context, 296 LLVMTypeConverter &lowering_) 297 : ConvertToLLVMPattern(TransposeOp::getOperationName(), context, 298 lowering_) {} 299 300 LogicalResult 301 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 302 ConversionPatternRewriter &rewriter) const override { 303 // Initialize the common boilerplate and alloca at the top of the FuncOp. 304 edsc::ScopedContext context(rewriter, op->getLoc()); 305 TransposeOpOperandAdaptor adaptor(operands); 306 BaseViewConversionHelper baseDesc(adaptor.view()); 307 308 auto transposeOp = cast<TransposeOp>(op); 309 // No permutation, early exit. 310 if (transposeOp.permutation().isIdentity()) 311 return rewriter.replaceOp(op, {baseDesc}), success(); 312 313 BaseViewConversionHelper desc( 314 typeConverter.convertType(transposeOp.getShapedType())); 315 316 // Copy the base and aligned pointers from the old descriptor to the new 317 // one. 318 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 319 desc.setAlignedPtr(baseDesc.alignedPtr()); 320 321 // Copy the offset pointer from the old descriptor to the new one. 322 desc.setOffset(baseDesc.offset()); 323 324 // Iterate over the dimensions and apply size/stride permutation. 325 for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { 326 int sourcePos = en.index(); 327 int targetPos = en.value().cast<AffineDimExpr>().getPosition(); 328 desc.setSize(targetPos, baseDesc.size(sourcePos)); 329 desc.setStride(targetPos, baseDesc.stride(sourcePos)); 330 } 331 332 rewriter.replaceOp(op, {desc}); 333 return success(); 334 } 335 }; 336 337 // YieldOp produces and LLVM::ReturnOp. 
338 class YieldOpConversion : public ConvertToLLVMPattern { 339 public: 340 explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) 341 : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {} 342 343 LogicalResult 344 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 345 ConversionPatternRewriter &rewriter) const override { 346 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands); 347 return success(); 348 } 349 }; 350 } // namespace 351 352 template <typename LinalgOp> 353 static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) { 354 return SmallVector<Type, 4>{op->getOperandTypes()}; 355 } 356 357 template <> 358 SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) { 359 auto ctx = op->getContext(); 360 auto indexedGenericOp = cast<IndexedGenericOp>(op); 361 auto numLoops = indexedGenericOp.getNumLoops(); 362 363 SmallVector<Type, 4> result; 364 result.reserve(numLoops + op->getNumOperands()); 365 for (unsigned i = 0; i < numLoops; ++i) { 366 result.push_back(IndexType::get(ctx)); 367 } 368 for (auto type : op->getOperandTypes()) { 369 result.push_back(type); 370 } 371 return result; 372 } 373 374 // Get a SymbolRefAttr containing the library function name for the LinalgOp. 375 // If the library function does not exist, insert a declaration. 376 template <typename LinalgOp> 377 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, 378 PatternRewriter &rewriter) { 379 auto linalgOp = cast<LinalgOp>(op); 380 auto fnName = linalgOp.getLibraryCallName(); 381 if (fnName.empty()) { 382 op->emitWarning("No library call defined for: ") << *op; 383 return {}; 384 } 385 386 // fnName is a dynamic std::String, unique it via a SymbolRefAttr. 
387 FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); 388 auto module = op->getParentOfType<ModuleOp>(); 389 if (module.lookupSymbol(fnName)) { 390 return fnNameAttr; 391 } 392 393 SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op)); 394 assert(op->getNumResults() == 0 && 395 "Library call for linalg operation can be generated only for ops that " 396 "have void return types"); 397 auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext()); 398 399 OpBuilder::InsertionGuard guard(rewriter); 400 // Insert before module terminator. 401 rewriter.setInsertionPoint(module.getBody(), 402 std::prev(module.getBody()->end())); 403 FuncOp funcOp = 404 rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType, 405 ArrayRef<NamedAttribute>{}); 406 // Insert a function attribute that will trigger the emission of the 407 // corresponding `_mlir_ciface_xxx` interface so that external libraries see 408 // a normalized ABI. This interface is added during std to llvm conversion. 409 funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext())); 410 return fnNameAttr; 411 } 412 413 namespace { 414 415 // LinalgOpConversion<LinalgOp> creates a new call to the 416 // `LinalgOp::getLibraryCallName()` function. 417 // The implementation of the function can be either in the same module or in an 418 // externally linked library. 
419 template <typename LinalgOp> 420 class LinalgOpConversion : public OpRewritePattern<LinalgOp> { 421 public: 422 using OpRewritePattern<LinalgOp>::OpRewritePattern; 423 424 LogicalResult matchAndRewrite(LinalgOp op, 425 PatternRewriter &rewriter) const override { 426 auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter); 427 if (!libraryCallName) 428 return failure(); 429 430 rewriter.replaceOpWithNewOp<mlir::CallOp>( 431 op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands()); 432 return success(); 433 } 434 }; 435 436 /// Conversion pattern specialization for CopyOp. This kicks in when both input 437 /// and output permutations are left unspecified or are the identity. 438 template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> { 439 public: 440 using OpRewritePattern<CopyOp>::OpRewritePattern; 441 442 LogicalResult matchAndRewrite(CopyOp op, 443 PatternRewriter &rewriter) const override { 444 auto inputPerm = op.inputPermutation(); 445 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 446 return failure(); 447 auto outputPerm = op.outputPermutation(); 448 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 449 return failure(); 450 451 auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter); 452 if (!libraryCallName) 453 return failure(); 454 455 rewriter.replaceOpWithNewOp<mlir::CallOp>( 456 op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands()); 457 return success(); 458 } 459 }; 460 461 /// Conversion pattern specialization for IndexedGenericOp. 
462 template <> 463 class LinalgOpConversion<IndexedGenericOp> 464 : public OpRewritePattern<IndexedGenericOp> { 465 public: 466 using OpRewritePattern<IndexedGenericOp>::OpRewritePattern; 467 468 LogicalResult matchAndRewrite(IndexedGenericOp op, 469 PatternRewriter &rewriter) const override { 470 auto libraryCallName = 471 getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter); 472 if (!libraryCallName) 473 return failure(); 474 475 // TODO(pifon, ntv): Use induction variables values instead of zeros, when 476 // IndexedGenericOp is tiled. 477 auto zero = rewriter.create<mlir::ConstantOp>( 478 op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); 479 auto indexedGenericOp = cast<IndexedGenericOp>(op); 480 auto numLoops = indexedGenericOp.getNumLoops(); 481 SmallVector<Value, 4> operands; 482 operands.reserve(numLoops + op.getNumOperands()); 483 for (unsigned i = 0; i < numLoops; ++i) { 484 operands.push_back(zero); 485 } 486 for (auto operand : op.getOperands()) { 487 operands.push_back(operand); 488 } 489 rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(), 490 ArrayRef<Type>{}, operands); 491 return success(); 492 } 493 }; 494 495 /// A non-conversion rewrite pattern kicks in to convert CopyOp with 496 /// permutations into a sequence of TransposeOp and permutation-free CopyOp. 497 /// This interplays together with TransposeOpConversion and 498 /// LinalgConversion<CopyOp> to create a path to the LLVM dialect. 499 class CopyTransposeConversion : public OpRewritePattern<CopyOp> { 500 public: 501 using OpRewritePattern<CopyOp>::OpRewritePattern; 502 503 LogicalResult matchAndRewrite(CopyOp op, 504 PatternRewriter &rewriter) const override { 505 Value in = op.input(), out = op.output(); 506 507 // If either inputPerm or outputPerm are non-identities, insert transposes. 
508 auto inputPerm = op.inputPermutation(); 509 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 510 in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in, 511 AffineMapAttr::get(*inputPerm)); 512 auto outputPerm = op.outputPermutation(); 513 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 514 out = rewriter.create<linalg::TransposeOp>( 515 op.getLoc(), out, AffineMapAttr::get(*outputPerm)); 516 517 // If nothing was transposed, fail and let the conversion kick in. 518 if (in == op.input() && out == op.output()) 519 return failure(); 520 521 rewriter.replaceOpWithNewOp<CopyOp>(op, in, out); 522 return success(); 523 } 524 }; 525 526 /// Populate the given list with patterns that convert from Linalg to Standard. 527 static void 528 populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns, 529 MLIRContext *ctx) { 530 // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant 531 // attribute values such as kernel striding and dilation. 532 // clang-format off 533 patterns.insert< 534 CopyTransposeConversion, 535 LinalgOpConversion<ConvOp>, 536 LinalgOpConversion<PoolingMaxOp>, 537 LinalgOpConversion<PoolingMinOp>, 538 LinalgOpConversion<PoolingSumOp>, 539 LinalgOpConversion<CopyOp>, 540 LinalgOpConversion<DotOp>, 541 LinalgOpConversion<FillOp>, 542 LinalgOpConversion<GenericOp>, 543 LinalgOpConversion<IndexedGenericOp>, 544 LinalgOpConversion<MatmulOp>, 545 LinalgOpConversion<MatvecOp>>(ctx); 546 // clang-format on 547 } 548 549 } // namespace 550 551 /// Populate the given list with patterns that convert from Linalg to LLVM. 552 void mlir::populateLinalgToLLVMConversionPatterns( 553 LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 554 MLIRContext *ctx) { 555 patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion, 556 TransposeOpConversion, YieldOpConversion>(ctx, converter); 557 558 // Populate the type conversions for the linalg types. 
559 converter.addConversion( 560 [&](RangeType type) { return convertRangeType(type, converter); }); 561 } 562 563 namespace { 564 struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> { 565 /// Include the generated pass utilities. 566 #define GEN_PASS_ConvertLinalgToLLVM 567 #include "mlir/Conversion/Passes.h.inc" 568 569 void runOnModule() override; 570 }; 571 } // namespace 572 573 void ConvertLinalgToLLVMPass::runOnModule() { 574 auto module = getModule(); 575 576 // Convert to the LLVM IR dialect using the converter defined above. 577 OwningRewritePatternList patterns; 578 LLVMTypeConverter converter(&getContext()); 579 populateAffineToStdConversionPatterns(patterns, &getContext()); 580 populateLoopToStdConversionPatterns(patterns, &getContext()); 581 populateStdToLLVMConversionPatterns(converter, patterns); 582 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 583 populateVectorToLLVMConversionPatterns(converter, patterns); 584 populateLinalgToStandardConversionPatterns(patterns, &getContext()); 585 populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); 586 587 LLVMConversionTarget target(getContext()); 588 target.addDynamicallyLegalOp<FuncOp>( 589 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 590 target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); 591 if (failed(applyFullConversion(module, target, patterns, &converter))) 592 signalPassFailure(); 593 } 594 595 std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertLinalgToLLVMPass() { 596 return std::make_unique<ConvertLinalgToLLVMPass>(); 597 } 598