1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
11 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
20 #include "mlir/EDSC/Builders.h"
21 #include "mlir/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Attributes.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/MLIRContext.h"
27 #include "mlir/IR/Module.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/IR/StandardTypes.h"
31 #include "mlir/IR/Types.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Pass/PassManager.h"
34 #include "mlir/Support/LogicalResult.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 
38 #include "llvm/ADT/SetVector.h"
39 #include "llvm/IR/DerivedTypes.h"
40 #include "llvm/IR/Module.h"
41 #include "llvm/IR/Type.h"
42 #include "llvm/Support/Allocator.h"
43 #include "llvm/Support/ErrorHandling.h"
44 
45 using namespace mlir;
46 using namespace mlir::edsc;
47 using namespace mlir::edsc::intrinsics;
48 using namespace mlir::LLVM;
49 using namespace mlir::linalg;
50 using namespace mlir::linalg::intrinsics;
51 
52 using add = ValueBuilder<mlir::LLVM::AddOp>;
53 using addi = ValueBuilder<mlir::AddIOp>;
54 using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
55 using cmpi = ValueBuilder<mlir::CmpIOp>;
56 using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
57 using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
58 using gep = ValueBuilder<mlir::LLVM::GEPOp>;
59 using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
60 using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
61 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
62 using llvm_load = ValueBuilder<LLVM::LoadOp>;
63 using llvm_store = OperationBuilder<LLVM::StoreOp>;
64 using llvm_select = ValueBuilder<LLVM::SelectOp>;
65 using mul = ValueBuilder<mlir::LLVM::MulOp>;
66 using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
67 using sub = ValueBuilder<mlir::LLVM::SubOp>;
68 using llvm_undef = ValueBuilder<mlir::LLVM::UndefOp>;
69 using urem = ValueBuilder<mlir::LLVM::URemOp>;
70 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
71 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
72 
73 namespace {
74 
75 template <typename T>
76 static LLVMType getPtrToElementType(T containerType,
77                                     LLVMTypeConverter &lowering) {
78   return lowering.convertType(containerType.getElementType())
79       .template cast<LLVMType>()
80       .getPointerTo();
81 }
82 
83 // Convert the given type to the LLVM IR Dialect type.  The following
84 // conversions are supported:
85 //   - an Index type is converted into an LLVM integer type with pointer
86 //     bitwidth (analogous to intptr_t in C);
87 //   - an Integer type is converted into an LLVM integer type of the same width;
88 //   - an F32 type is converted into an LLVM float type
89 //   - a Buffer, Range or View is converted into an LLVM structure type
90 //     containing the respective dynamic values.
91 static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
92   auto *context = t.getContext();
93   auto int64Ty = lowering.convertType(IntegerType::get(64, context))
94                      .cast<LLVM::LLVMType>();
95 
96   // Range descriptor contains the range bounds and the step as 64-bit integers.
97   //
98   // struct {
99   //   int64_t min;
100   //   int64_t max;
101   //   int64_t step;
102   // };
103   if (t.isa<RangeType>())
104     return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
105 
106   return Type();
107 }
108 
109 /// EDSC-compatible wrapper for MemRefDescriptor.
110 class BaseViewConversionHelper {
111 public:
112   BaseViewConversionHelper(Type type)
113       : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
114 
115   BaseViewConversionHelper(Value v) : d(v) {}
116 
117   /// Wrappers around MemRefDescriptor that use EDSC builder and location.
118   Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
119   void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
120   Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
121   void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
122   Value offset() { return d.offset(rewriter(), loc()); }
123   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
124   Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
125   void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
126   void setConstantSize(unsigned i, int64_t v) {
127     d.setConstantSize(rewriter(), loc(), i, v);
128   }
129   Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
130   void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
131   void setConstantStride(unsigned i, int64_t v) {
132     d.setConstantStride(rewriter(), loc(), i, v);
133   }
134 
135   operator Value() { return d; }
136 
137 private:
138   OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
139   Location loc() { return ScopedContext::getLocation(); }
140 
141   MemRefDescriptor d;
142 };
143 
144 // RangeOp creates a new range descriptor.
145 class RangeOpConversion : public LLVMOpLowering {
146 public:
147   explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
148       : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
149 
150   PatternMatchResult
151   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
152                   ConversionPatternRewriter &rewriter) const override {
153     auto rangeOp = cast<RangeOp>(op);
154     auto rangeDescriptorTy =
155         convertLinalgType(rangeOp.getResult().getType(), lowering);
156 
157     edsc::ScopedContext context(rewriter, op->getLoc());
158 
159     // Fill in an aggregate value of the descriptor.
160     RangeOpOperandAdaptor adaptor(operands);
161     Value desc = llvm_undef(rangeDescriptorTy);
162     desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
163     desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
164     desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
165     rewriter.replaceOp(op, desc);
166     return matchSuccess();
167   }
168 };
169 
170 // ReshapeOp creates a new view descriptor of the proper rank.
171 // For now, the only conversion supported is for target MemRef with static sizes
172 // and strides.
173 class ReshapeOpConversion : public LLVMOpLowering {
174 public:
175   explicit ReshapeOpConversion(MLIRContext *context,
176                                LLVMTypeConverter &lowering_)
177       : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {}
178 
179   PatternMatchResult
180   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
181                   ConversionPatternRewriter &rewriter) const override {
182     auto reshapeOp = cast<ReshapeOp>(op);
183     MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
184 
185     if (!dstType.hasStaticShape())
186       return matchFailure();
187 
188     int64_t offset;
189     SmallVector<int64_t, 4> strides;
190     auto res = getStridesAndOffset(dstType, strides, offset);
191     if (failed(res) || llvm::any_of(strides, [](int64_t val) {
192           return ShapedType::isDynamicStrideOrOffset(val);
193         }))
194       return matchFailure();
195 
196     edsc::ScopedContext context(rewriter, op->getLoc());
197     ReshapeOpOperandAdaptor adaptor(operands);
198     BaseViewConversionHelper baseDesc(adaptor.view());
199     BaseViewConversionHelper desc(lowering.convertType(dstType));
200     desc.setAllocatedPtr(baseDesc.allocatedPtr());
201     desc.setAlignedPtr(baseDesc.alignedPtr());
202     desc.setOffset(baseDesc.offset());
203     for (auto en : llvm::enumerate(dstType.getShape()))
204       desc.setConstantSize(en.index(), en.value());
205     for (auto en : llvm::enumerate(strides))
206       desc.setConstantStride(en.index(), en.value());
207     rewriter.replaceOp(op, {desc});
208     return matchSuccess();
209   }
210 };
211 
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An `undef` descriptor for the resulting (possibly rank-reduced) view.
///   2. Updates to that descriptor introducing the data pointers, the offset,
///      and the sizes/strides corresponding to the region of memory within
///      the bounds of the parent view.
/// The linalg.slice op is replaced by the populated descriptor value.
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Set up the implicit EDSC builder/location used by the ValueBuilder
    // helpers (extractvalue, add, mul, ...) below.
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpOperandAdaptor adaptor(operands);
    // Descriptor of the view being sliced.
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Descriptor of the resulting view; one dimension per range indexing.
    BaseViewConversionHelper desc(
        lowering.convertType(sliceOp.getShapedType()));

    // TODO(ntv): extract sizes and emit asserts.
    // Cache the base view's strides, needed for both the offset and the
    // per-dimension stride computations below.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Helper building the I64 array attribute used as an extractvalue
    // position.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: accumulate min_i * stride_i over all indexings.
    // A range indexing contributes its `min` field (struct element 0); a
    // scalar index is used directly.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = extractvalue(int64Ty, indexing, pos(0));
      baseOffset = add(baseOffset, mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), matchSuccess();

    Value zero =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        // Unpack the {min, max, step} fields of the range descriptor.
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // The resulting stride is the base stride scaled by the range step.
        Value stride = mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return matchSuccess();
  }
};
303 
304 /// Conversion pattern that transforms a linalg.transpose op into:
305 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
306 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
307 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
308 ///      and stride. Size and stride are permutations of the original values.
309 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
310 /// The linalg.transpose op is replaced by the alloca'ed pointer.
311 class TransposeOpConversion : public LLVMOpLowering {
312 public:
313   explicit TransposeOpConversion(MLIRContext *context,
314                                  LLVMTypeConverter &lowering_)
315       : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}
316 
317   PatternMatchResult
318   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
319                   ConversionPatternRewriter &rewriter) const override {
320     // Initialize the common boilerplate and alloca at the top of the FuncOp.
321     edsc::ScopedContext context(rewriter, op->getLoc());
322     TransposeOpOperandAdaptor adaptor(operands);
323     BaseViewConversionHelper baseDesc(adaptor.view());
324 
325     auto transposeOp = cast<TransposeOp>(op);
326     // No permutation, early exit.
327     if (transposeOp.permutation().isIdentity())
328       return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
329 
330     BaseViewConversionHelper desc(
331         lowering.convertType(transposeOp.getShapedType()));
332 
333     // Copy the base and aligned pointers from the old descriptor to the new
334     // one.
335     desc.setAllocatedPtr(baseDesc.allocatedPtr());
336     desc.setAlignedPtr(baseDesc.alignedPtr());
337 
338     // Copy the offset pointer from the old descriptor to the new one.
339     desc.setOffset(baseDesc.offset());
340 
341     // Iterate over the dimensions and apply size/stride permutation.
342     for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
343       int sourcePos = en.index();
344       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
345       desc.setSize(targetPos, baseDesc.size(sourcePos));
346       desc.setStride(targetPos, baseDesc.stride(sourcePos));
347     }
348 
349     rewriter.replaceOp(op, {desc});
350     return matchSuccess();
351   }
352 };
353 
354 // YieldOp produces and LLVM::ReturnOp.
355 class YieldOpConversion : public LLVMOpLowering {
356 public:
357   explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
358       : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {}
359 
360   PatternMatchResult
361   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
362                   ConversionPatternRewriter &rewriter) const override {
363     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
364     return matchSuccess();
365   }
366 };
367 
368 template <typename LinalgOp>
369 static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
370   return SmallVector<Type, 4>{op->getOperandTypes()};
371 }
372 
373 template <>
374 SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
375   auto ctx = op->getContext();
376   auto indexedGenericOp = cast<IndexedGenericOp>(op);
377   auto numLoops = indexedGenericOp.getNumLoops();
378 
379   SmallVector<Type, 4> result;
380   result.reserve(numLoops + op->getNumOperands());
381   for (unsigned i = 0; i < numLoops; ++i) {
382     result.push_back(IndexType::get(ctx));
383   }
384   for (auto type : op->getOperandTypes()) {
385     result.push_back(type);
386   }
387   return result;
388 }
389 
390 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
391 // If the library function does not exist, insert a declaration.
392 template <typename LinalgOp>
393 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
394                                                  PatternRewriter &rewriter) {
395   auto linalgOp = cast<LinalgOp>(op);
396   auto fnName = linalgOp.getLibraryCallName();
397   if (fnName.empty()) {
398     op->emitWarning("No library call defined for: ") << *op;
399     return {};
400   }
401 
402   // fnName is a dynamic std::String, unique it via a SymbolRefAttr.
403   FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
404   auto module = op->getParentOfType<ModuleOp>();
405   if (module.lookupSymbol(fnName)) {
406     return fnNameAttr;
407   }
408 
409   SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
410   assert(op->getNumResults() == 0 &&
411          "Library call for linalg operation can be generated only for ops that "
412          "have void return types");
413   auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
414 
415   OpBuilder::InsertionGuard guard(rewriter);
416   // Insert before module terminator.
417   rewriter.setInsertionPoint(module.getBody(),
418                              std::prev(module.getBody()->end()));
419   rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
420                           ArrayRef<NamedAttribute>{});
421   return fnNameAttr;
422 }
423 
424 } // namespace
425 
426 Type LinalgTypeConverter::convertType(Type t) {
427   if (auto result = LLVMTypeConverter::convertType(t))
428     return result;
429   return convertLinalgType(t, *this);
430 }
431 
432 namespace {
433 
434 // LinalgOpConversion<LinalgOp> creates a new call to the
435 // `LinalgOp::getLibraryCallName()` function.
436 // The implementation of the function can be either in the same module or in an
437 // externally linked library.
438 template <typename LinalgOp>
439 class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
440 public:
441   using OpRewritePattern<LinalgOp>::OpRewritePattern;
442 
443   PatternMatchResult matchAndRewrite(LinalgOp op,
444                                      PatternRewriter &rewriter) const override {
445     auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
446     if (!libraryCallName)
447       return this->matchFailure();
448 
449     rewriter.replaceOpWithNewOp<mlir::CallOp>(
450         op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
451     return this->matchSuccess();
452   }
453 };
454 
455 /// Conversion pattern specialization for CopyOp. This kicks in when both input
456 /// and output permutations are left unspecified or are the identity.
457 template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
458 public:
459   using OpRewritePattern<CopyOp>::OpRewritePattern;
460 
461   PatternMatchResult matchAndRewrite(CopyOp op,
462                                      PatternRewriter &rewriter) const override {
463     auto inputPerm = op.inputPermutation();
464     if (inputPerm.hasValue() && !inputPerm->isIdentity())
465       return matchFailure();
466     auto outputPerm = op.outputPermutation();
467     if (outputPerm.hasValue() && !outputPerm->isIdentity())
468       return matchFailure();
469 
470     auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
471     if (!libraryCallName)
472       return matchFailure();
473 
474     rewriter.replaceOpWithNewOp<mlir::CallOp>(
475         op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
476     return matchSuccess();
477   }
478 };
479 
480 /// Conversion pattern specialization for IndexedGenericOp.
481 template <>
482 class LinalgOpConversion<IndexedGenericOp>
483     : public OpRewritePattern<IndexedGenericOp> {
484 public:
485   using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
486 
487   PatternMatchResult matchAndRewrite(IndexedGenericOp op,
488                                      PatternRewriter &rewriter) const override {
489     auto libraryCallName =
490         getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
491     if (!libraryCallName)
492       return this->matchFailure();
493 
494     // TODO(pifon, ntv): Use induction variables values instead of zeros, when
495     // IndexedGenericOp is tiled.
496     auto zero = rewriter.create<mlir::ConstantOp>(
497         op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
498     auto indexedGenericOp = cast<IndexedGenericOp>(op);
499     auto numLoops = indexedGenericOp.getNumLoops();
500     SmallVector<Value, 4> operands;
501     operands.reserve(numLoops + op.getNumOperands());
502     for (unsigned i = 0; i < numLoops; ++i) {
503       operands.push_back(zero);
504     }
505     for (auto operand : op.getOperands()) {
506       operands.push_back(operand);
507     }
508     rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
509                                               ArrayRef<Type>{}, operands);
510     return this->matchSuccess();
511   }
512 };
513 
514 /// A non-conversion rewrite pattern kicks in to convert CopyOp with
515 /// permutations into a sequence of TransposeOp and permutation-free CopyOp.
516 /// This interplays together with TransposeOpConversion and
517 /// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
518 class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
519 public:
520   using OpRewritePattern<CopyOp>::OpRewritePattern;
521 
522   PatternMatchResult matchAndRewrite(CopyOp op,
523                                      PatternRewriter &rewriter) const override {
524     Value in = op.input(), out = op.output();
525 
526     // If either inputPerm or outputPerm are non-identities, insert transposes.
527     auto inputPerm = op.inputPermutation();
528     if (inputPerm.hasValue() && !inputPerm->isIdentity())
529       in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
530                                                 AffineMapAttr::get(*inputPerm));
531     auto outputPerm = op.outputPermutation();
532     if (outputPerm.hasValue() && !outputPerm->isIdentity())
533       out = rewriter.create<linalg::TransposeOp>(
534           op.getLoc(), out, AffineMapAttr::get(*outputPerm));
535 
536     // If nothing was transposed, fail and let the conversion kick in.
537     if (in == op.input() && out == op.output())
538       return matchFailure();
539 
540     rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
541     return matchSuccess();
542   }
543 };
544 
545 /// Populate the given list with patterns that convert from Linalg to Standard.
546 static void
547 populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
548                                            MLIRContext *ctx) {
549   // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
550   // attribute values such as kernel striding and dilation.
551   patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
552                   LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
553                   LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
554                   LinalgOpConversion<IndexedGenericOp>,
555                   LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
556       ctx);
557 }
558 
559 } // namespace
560 
/// Populate the given list with patterns that convert from Linalg to LLVM.
/// These are the direct descriptor-manipulation lowerings defined above;
/// library-call lowerings are registered separately by
/// populateLinalgToStandardConversionPatterns.
void mlir::populateLinalgToLLVMConversionPatterns(
    LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
    MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
                  TransposeOpConversion, YieldOpConversion>(ctx, converter);
}
568 
569 namespace {
// Module pass performing the full Linalg (+ affine/loop/std/vector) to LLVM
// dialect conversion.
struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
  void runOnModule() override;
};
573 } // namespace
574 
575 void ConvertLinalgToLLVMPass::runOnModule() {
576   auto module = getModule();
577 
578   // Convert to the LLVM IR dialect using the converter defined above.
579   OwningRewritePatternList patterns;
580   LinalgTypeConverter converter(&getContext());
581   populateAffineToStdConversionPatterns(patterns, &getContext());
582   populateLoopToStdConversionPatterns(patterns, &getContext());
583   populateStdToLLVMConversionPatterns(converter, patterns);
584   populateVectorToLLVMConversionPatterns(converter, patterns);
585   populateLinalgToStandardConversionPatterns(patterns, &getContext());
586   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
587 
588   ConversionTarget target(getContext());
589   target.addLegalDialect<LLVM::LLVMDialect>();
590   target.addDynamicallyLegalOp<FuncOp>(
591       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
592   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
593   if (failed(applyFullConversion(module, target, patterns, &converter)))
594     signalPassFailure();
595 }
596 
597 std::unique_ptr<OpPassBase<ModuleOp>>
598 mlir::linalg::createConvertLinalgToLLVMPass() {
599   return std::make_unique<ConvertLinalgToLLVMPass>();
600 }
601 
// Register the pass with the global registry under the
// `-convert-linalg-to-llvm` command-line flag.
static PassRegistration<ConvertLinalgToLLVMPass> pass(
    "convert-linalg-to-llvm",
    "Convert the operations from the linalg dialect into the LLVM dialect");
605