1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 
11 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
12 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
15 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Module.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/StandardTypes.h"
30 #include "mlir/IR/Types.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Pass/PassManager.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Transforms/DialectConversion.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/SetVector.h"
37 #include "llvm/IR/DerivedTypes.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/Support/Allocator.h"
41 #include "llvm/Support/ErrorHandling.h"
42 
43 using namespace mlir;
44 using namespace mlir::edsc;
45 using namespace mlir::edsc::intrinsics;
46 using namespace mlir::LLVM;
47 using namespace mlir::linalg;
48 
49 using llvm_add = ValueBuilder<LLVM::AddOp>;
50 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
51 using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
52 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
53 using llvm_gep = ValueBuilder<LLVM::GEPOp>;
54 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
55 using llvm_call = OperationBuilder<LLVM::CallOp>;
56 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
57 using llvm_load = ValueBuilder<LLVM::LoadOp>;
58 using llvm_store = OperationBuilder<LLVM::StoreOp>;
59 using llvm_select = ValueBuilder<LLVM::SelectOp>;
60 using llvm_mul = ValueBuilder<LLVM::MulOp>;
61 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
62 using llvm_sub = ValueBuilder<LLVM::SubOp>;
63 using llvm_undef = ValueBuilder<LLVM::UndefOp>;
64 using llvm_urem = ValueBuilder<LLVM::URemOp>;
65 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
66 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
67 
68 template <typename T>
69 static LLVMType getPtrToElementType(T containerType,
70                                     LLVMTypeConverter &lowering) {
71   return lowering.convertType(containerType.getElementType())
72       .template cast<LLVMType>()
73       .getPointerTo();
74 }
75 
76 /// Convert the given range descriptor type to the LLVMIR dialect.
77 /// Range descriptor contains the range bounds and the step as 64-bit integers.
78 ///
79 /// struct {
80 ///   int64_t min;
81 ///   int64_t max;
82 ///   int64_t step;
83 /// };
84 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
85   auto *context = t.getContext();
86   auto int64Ty = converter.convertType(IntegerType::get(64, context))
87                      .cast<LLVM::LLVMType>();
88   return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
89 }
90 
91 namespace {
92 /// EDSC-compatible wrapper for MemRefDescriptor.
93 class BaseViewConversionHelper {
94 public:
95   BaseViewConversionHelper(Type type)
96       : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
97 
98   BaseViewConversionHelper(Value v) : d(v) {}
99 
100   /// Wrappers around MemRefDescriptor that use EDSC builder and location.
101   Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
102   void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
103   Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
104   void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
105   Value offset() { return d.offset(rewriter(), loc()); }
106   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
107   Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
108   void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
109   void setConstantSize(unsigned i, int64_t v) {
110     d.setConstantSize(rewriter(), loc(), i, v);
111   }
112   Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
113   void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
114   void setConstantStride(unsigned i, int64_t v) {
115     d.setConstantStride(rewriter(), loc(), i, v);
116   }
117 
118   operator Value() { return d; }
119 
120 private:
121   OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
122   Location loc() { return ScopedContext::getLocation(); }
123 
124   MemRefDescriptor d;
125 };
126 
127 // RangeOp creates a new range descriptor.
128 class RangeOpConversion : public ConvertToLLVMPattern {
129 public:
130   explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
131       : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
132 
133   LogicalResult
134   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
135                   ConversionPatternRewriter &rewriter) const override {
136     auto rangeOp = cast<RangeOp>(op);
137     auto rangeDescriptorTy =
138         convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
139 
140     edsc::ScopedContext context(rewriter, op->getLoc());
141 
142     // Fill in an aggregate value of the descriptor.
143     RangeOpOperandAdaptor adaptor(operands);
144     Value desc = llvm_undef(rangeDescriptorTy);
145     desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
146     desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
147     desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
148     rewriter.replaceOp(op, desc);
149     return success();
150   }
151 };
152 
153 // ReshapeOp creates a new view descriptor of the proper rank.
154 // For now, the only conversion supported is for target MemRef with static sizes
155 // and strides.
156 class ReshapeOpConversion : public ConvertToLLVMPattern {
157 public:
158   explicit ReshapeOpConversion(MLIRContext *context,
159                                LLVMTypeConverter &lowering_)
160       : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
161                              lowering_) {}
162 
163   LogicalResult
164   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
165                   ConversionPatternRewriter &rewriter) const override {
166     auto reshapeOp = cast<ReshapeOp>(op);
167     MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
168 
169     if (!dstType.hasStaticShape())
170       return failure();
171 
172     int64_t offset;
173     SmallVector<int64_t, 4> strides;
174     auto res = getStridesAndOffset(dstType, strides, offset);
175     if (failed(res) || llvm::any_of(strides, [](int64_t val) {
176           return ShapedType::isDynamicStrideOrOffset(val);
177         }))
178       return failure();
179 
180     edsc::ScopedContext context(rewriter, op->getLoc());
181     ReshapeOpOperandAdaptor adaptor(operands);
182     BaseViewConversionHelper baseDesc(adaptor.view());
183     BaseViewConversionHelper desc(typeConverter.convertType(dstType));
184     desc.setAllocatedPtr(baseDesc.allocatedPtr());
185     desc.setAlignedPtr(baseDesc.alignedPtr());
186     desc.setOffset(baseDesc.offset());
187     for (auto en : llvm::enumerate(dstType.getShape()))
188       desc.setConstantSize(en.index(), en.value());
189     for (auto en : llvm::enumerate(strides))
190       desc.setConstantStride(en.index(), en.value());
191     rewriter.replaceOp(op, {desc});
192     return success();
193   }
194 };
195 
196 /// Conversion pattern that transforms a linalg.slice op into:
197 ///   1. An "undef" value for the ViewDescriptor.
198 ///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
199 ///      and stride corresponding to the region of memory within the bounds of
200 ///      the parent view.
201 /// The linalg.slice op is replaced by the alloca'ed pointer.
202 class SliceOpConversion : public ConvertToLLVMPattern {
203 public:
204   explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
205       : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
206 
207   LogicalResult
208   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
209                   ConversionPatternRewriter &rewriter) const override {
210     edsc::ScopedContext context(rewriter, op->getLoc());
211     SliceOpOperandAdaptor adaptor(operands);
212     BaseViewConversionHelper baseDesc(adaptor.view());
213 
214     auto sliceOp = cast<SliceOp>(op);
215     auto memRefType = sliceOp.getBaseViewType();
216     auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
217                        .cast<LLVM::LLVMType>();
218 
219     BaseViewConversionHelper desc(
220         typeConverter.convertType(sliceOp.getShapedType()));
221 
222     // TODO(ntv): extract sizes and emit asserts.
223     SmallVector<Value, 4> strides(memRefType.getRank());
224     for (int i = 0, e = memRefType.getRank(); i < e; ++i)
225       strides[i] = baseDesc.stride(i);
226 
227     auto pos = [&rewriter](ArrayRef<int64_t> values) {
228       return rewriter.getI64ArrayAttr(values);
229     };
230 
231     // Compute base offset.
232     Value baseOffset = baseDesc.offset();
233     for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
234       Value indexing = adaptor.indexings()[i];
235       Value min = indexing;
236       if (sliceOp.indexing(i).getType().isa<RangeType>())
237         min = llvm_extractvalue(int64Ty, indexing, pos(0));
238       baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
239     }
240 
241     // Insert the base and aligned pointers.
242     desc.setAllocatedPtr(baseDesc.allocatedPtr());
243     desc.setAlignedPtr(baseDesc.alignedPtr());
244 
245     // Insert base offset.
246     desc.setOffset(baseOffset);
247 
248     // Corner case, no sizes or strides: early return the descriptor.
249     if (sliceOp.getShapedType().getRank() == 0)
250       return rewriter.replaceOp(op, {desc}), success();
251 
252     Value zero = llvm_constant(
253         int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
254     // Compute and insert view sizes (max - min along the range) and strides.
255     // Skip the non-range operands as they will be projected away from the view.
256     int numNewDims = 0;
257     for (auto en : llvm::enumerate(sliceOp.indexings())) {
258       Value indexing = en.value();
259       if (indexing.getType().isa<RangeType>()) {
260         int rank = en.index();
261         Value rangeDescriptor = adaptor.indexings()[rank];
262         Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
263         Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
264         Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
265         Value baseSize = baseDesc.size(rank);
266 
267         // Bound upper by base view upper bound.
268         max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
269                           baseSize);
270         Value size = llvm_sub(max, min);
271         // Bound lower by zero.
272         size =
273             llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
274         Value stride = llvm_mul(strides[rank], step);
275         desc.setSize(numNewDims, size);
276         desc.setStride(numNewDims, stride);
277         ++numNewDims;
278       }
279     }
280 
281     rewriter.replaceOp(op, {desc});
282     return success();
283   }
284 };
285 
286 /// Conversion pattern that transforms a linalg.transpose op into:
287 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
288 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
289 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
290 ///      and stride. Size and stride are permutations of the original values.
291 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
292 /// The linalg.transpose op is replaced by the alloca'ed pointer.
293 class TransposeOpConversion : public ConvertToLLVMPattern {
294 public:
295   explicit TransposeOpConversion(MLIRContext *context,
296                                  LLVMTypeConverter &lowering_)
297       : ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
298                              lowering_) {}
299 
300   LogicalResult
301   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
302                   ConversionPatternRewriter &rewriter) const override {
303     // Initialize the common boilerplate and alloca at the top of the FuncOp.
304     edsc::ScopedContext context(rewriter, op->getLoc());
305     TransposeOpOperandAdaptor adaptor(operands);
306     BaseViewConversionHelper baseDesc(adaptor.view());
307 
308     auto transposeOp = cast<TransposeOp>(op);
309     // No permutation, early exit.
310     if (transposeOp.permutation().isIdentity())
311       return rewriter.replaceOp(op, {baseDesc}), success();
312 
313     BaseViewConversionHelper desc(
314         typeConverter.convertType(transposeOp.getShapedType()));
315 
316     // Copy the base and aligned pointers from the old descriptor to the new
317     // one.
318     desc.setAllocatedPtr(baseDesc.allocatedPtr());
319     desc.setAlignedPtr(baseDesc.alignedPtr());
320 
321     // Copy the offset pointer from the old descriptor to the new one.
322     desc.setOffset(baseDesc.offset());
323 
324     // Iterate over the dimensions and apply size/stride permutation.
325     for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
326       int sourcePos = en.index();
327       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
328       desc.setSize(targetPos, baseDesc.size(sourcePos));
329       desc.setStride(targetPos, baseDesc.stride(sourcePos));
330     }
331 
332     rewriter.replaceOp(op, {desc});
333     return success();
334   }
335 };
336 
337 // YieldOp produces and LLVM::ReturnOp.
338 class YieldOpConversion : public ConvertToLLVMPattern {
339 public:
340   explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
341       : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
342 
343   LogicalResult
344   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
345                   ConversionPatternRewriter &rewriter) const override {
346     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
347     return success();
348   }
349 };
350 } // namespace
351 
352 template <typename LinalgOp>
353 static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
354   return SmallVector<Type, 4>{op->getOperandTypes()};
355 }
356 
357 template <>
358 SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
359   auto ctx = op->getContext();
360   auto indexedGenericOp = cast<IndexedGenericOp>(op);
361   auto numLoops = indexedGenericOp.getNumLoops();
362 
363   SmallVector<Type, 4> result;
364   result.reserve(numLoops + op->getNumOperands());
365   for (unsigned i = 0; i < numLoops; ++i) {
366     result.push_back(IndexType::get(ctx));
367   }
368   for (auto type : op->getOperandTypes()) {
369     result.push_back(type);
370   }
371   return result;
372 }
373 
374 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
375 // If the library function does not exist, insert a declaration.
376 template <typename LinalgOp>
377 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
378                                                  PatternRewriter &rewriter) {
379   auto linalgOp = cast<LinalgOp>(op);
380   auto fnName = linalgOp.getLibraryCallName();
381   if (fnName.empty()) {
382     op->emitWarning("No library call defined for: ") << *op;
383     return {};
384   }
385 
386   // fnName is a dynamic std::String, unique it via a SymbolRefAttr.
387   FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
388   auto module = op->getParentOfType<ModuleOp>();
389   if (module.lookupSymbol(fnName)) {
390     return fnNameAttr;
391   }
392 
393   SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
394   assert(op->getNumResults() == 0 &&
395          "Library call for linalg operation can be generated only for ops that "
396          "have void return types");
397   auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
398 
399   OpBuilder::InsertionGuard guard(rewriter);
400   // Insert before module terminator.
401   rewriter.setInsertionPoint(module.getBody(),
402                              std::prev(module.getBody()->end()));
403   FuncOp funcOp =
404       rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
405                               ArrayRef<NamedAttribute>{});
406   // Insert a function attribute that will trigger the emission of the
407   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
408   // a normalized ABI. This interface is added during std to llvm conversion.
409   funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
410   return fnNameAttr;
411 }
412 
413 namespace {
414 
415 // LinalgOpConversion<LinalgOp> creates a new call to the
416 // `LinalgOp::getLibraryCallName()` function.
417 // The implementation of the function can be either in the same module or in an
418 // externally linked library.
419 template <typename LinalgOp>
420 class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
421 public:
422   using OpRewritePattern<LinalgOp>::OpRewritePattern;
423 
424   LogicalResult matchAndRewrite(LinalgOp op,
425                                 PatternRewriter &rewriter) const override {
426     auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
427     if (!libraryCallName)
428       return failure();
429 
430     rewriter.replaceOpWithNewOp<mlir::CallOp>(
431         op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
432     return success();
433   }
434 };
435 
436 /// Conversion pattern specialization for CopyOp. This kicks in when both input
437 /// and output permutations are left unspecified or are the identity.
438 template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
439 public:
440   using OpRewritePattern<CopyOp>::OpRewritePattern;
441 
442   LogicalResult matchAndRewrite(CopyOp op,
443                                 PatternRewriter &rewriter) const override {
444     auto inputPerm = op.inputPermutation();
445     if (inputPerm.hasValue() && !inputPerm->isIdentity())
446       return failure();
447     auto outputPerm = op.outputPermutation();
448     if (outputPerm.hasValue() && !outputPerm->isIdentity())
449       return failure();
450 
451     auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
452     if (!libraryCallName)
453       return failure();
454 
455     rewriter.replaceOpWithNewOp<mlir::CallOp>(
456         op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
457     return success();
458   }
459 };
460 
461 /// Conversion pattern specialization for IndexedGenericOp.
462 template <>
463 class LinalgOpConversion<IndexedGenericOp>
464     : public OpRewritePattern<IndexedGenericOp> {
465 public:
466   using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
467 
468   LogicalResult matchAndRewrite(IndexedGenericOp op,
469                                 PatternRewriter &rewriter) const override {
470     auto libraryCallName =
471         getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
472     if (!libraryCallName)
473       return failure();
474 
475     // TODO(pifon, ntv): Use induction variables values instead of zeros, when
476     // IndexedGenericOp is tiled.
477     auto zero = rewriter.create<mlir::ConstantOp>(
478         op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
479     auto indexedGenericOp = cast<IndexedGenericOp>(op);
480     auto numLoops = indexedGenericOp.getNumLoops();
481     SmallVector<Value, 4> operands;
482     operands.reserve(numLoops + op.getNumOperands());
483     for (unsigned i = 0; i < numLoops; ++i) {
484       operands.push_back(zero);
485     }
486     for (auto operand : op.getOperands()) {
487       operands.push_back(operand);
488     }
489     rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
490                                               ArrayRef<Type>{}, operands);
491     return success();
492   }
493 };
494 
495 /// A non-conversion rewrite pattern kicks in to convert CopyOp with
496 /// permutations into a sequence of TransposeOp and permutation-free CopyOp.
497 /// This interplays together with TransposeOpConversion and
498 /// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
499 class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
500 public:
501   using OpRewritePattern<CopyOp>::OpRewritePattern;
502 
503   LogicalResult matchAndRewrite(CopyOp op,
504                                 PatternRewriter &rewriter) const override {
505     Value in = op.input(), out = op.output();
506 
507     // If either inputPerm or outputPerm are non-identities, insert transposes.
508     auto inputPerm = op.inputPermutation();
509     if (inputPerm.hasValue() && !inputPerm->isIdentity())
510       in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
511                                                 AffineMapAttr::get(*inputPerm));
512     auto outputPerm = op.outputPermutation();
513     if (outputPerm.hasValue() && !outputPerm->isIdentity())
514       out = rewriter.create<linalg::TransposeOp>(
515           op.getLoc(), out, AffineMapAttr::get(*outputPerm));
516 
517     // If nothing was transposed, fail and let the conversion kick in.
518     if (in == op.input() && out == op.output())
519       return failure();
520 
521     rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
522     return success();
523   }
524 };
525 
526 /// Populate the given list with patterns that convert from Linalg to Standard.
527 static void
528 populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
529                                            MLIRContext *ctx) {
530   // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
531   // attribute values such as kernel striding and dilation.
532   // clang-format off
533   patterns.insert<
534       CopyTransposeConversion,
535       LinalgOpConversion<ConvOp>,
536       LinalgOpConversion<PoolingMaxOp>,
537       LinalgOpConversion<PoolingMinOp>,
538       LinalgOpConversion<PoolingSumOp>,
539       LinalgOpConversion<CopyOp>,
540       LinalgOpConversion<DotOp>,
541       LinalgOpConversion<FillOp>,
542       LinalgOpConversion<GenericOp>,
543       LinalgOpConversion<IndexedGenericOp>,
544       LinalgOpConversion<MatmulOp>,
545       LinalgOpConversion<MatvecOp>>(ctx);
546   // clang-format on
547 }
548 
549 } // namespace
550 
551 /// Populate the given list with patterns that convert from Linalg to LLVM.
552 void mlir::populateLinalgToLLVMConversionPatterns(
553     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
554     MLIRContext *ctx) {
555   patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
556                   TransposeOpConversion, YieldOpConversion>(ctx, converter);
557 
558   // Populate the type conversions for the linalg types.
559   converter.addConversion(
560       [&](RangeType type) { return convertRangeType(type, converter); });
561 }
562 
563 namespace {
/// Module pass that lowers Linalg, together with Affine, Loop, Standard and
/// Vector ops, to the LLVM IR dialect in a single full dialect conversion (see
/// runOnModule).
struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertLinalgToLLVM
#include "mlir/Conversion/Passes.h.inc"

  void runOnModule() override;
};
571 } // namespace
572 
573 void ConvertLinalgToLLVMPass::runOnModule() {
574   auto module = getModule();
575 
576   // Convert to the LLVM IR dialect using the converter defined above.
577   OwningRewritePatternList patterns;
578   LLVMTypeConverter converter(&getContext());
579   populateAffineToStdConversionPatterns(patterns, &getContext());
580   populateLoopToStdConversionPatterns(patterns, &getContext());
581   populateStdToLLVMConversionPatterns(converter, patterns);
582   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
583   populateVectorToLLVMConversionPatterns(converter, patterns);
584   populateLinalgToStandardConversionPatterns(patterns, &getContext());
585   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
586 
587   LLVMConversionTarget target(getContext());
588   target.addDynamicallyLegalOp<FuncOp>(
589       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
590   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
591   if (failed(applyFullConversion(module, target, patterns, &converter)))
592     signalPassFailure();
593 }
594 
/// Creates an instance of ConvertLinalgToLLVMPass, which lowers Linalg (along
/// with Affine, Loop, Standard and Vector ops) to the LLVM IR dialect.
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}
598