//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

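/// Emits one affine.apply per result expression of `map`, applied to `vals`,
/// after canonicalizing each single-result map together with its operands.
/// Returns the resulting index values (empty if `map` has no results).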
static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};
  assert(map.getNumSymbols() == 0);
  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, 0, e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

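/// Applies `permutation` to `ivs` when a permutation map is provided;
/// otherwise returns `ivs` unchanged.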
static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
                                        Optional<AffineMap> permutation) {
  return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
                                        ScopedContext::getLocation(),
                                        permutation.getValue(), ivs)
                     : SmallVector<Value, 4>(ivs.begin(), ivs.end());
}

// Creates a number of ranges equal to the number of results in `map`.
// The returned ranges correspond to the loop ranges, in the proper order, for
// which new loops will be created.
static SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<Value> allViewSizes);
SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                                     ArrayRef<Value> allViewSizes) {
  // Apply `map` to get view sizes in loop order.
  auto sizes = applyMapToValues(b, loc, map, allViewSizes);
  // Create a [0, size, 1) range for each loop dimension.
  ScopedContext scope(b, loc);
  SmallVector<Value, 4> res;
  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
    res.push_back(
        linalg_range(std_constant_index(0), sizes[idx], std_constant_index(1)));
  }
  return res;
}

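/// Inlines the single-block region of `op` by cloning its body with the block
/// arguments remapped to `indexedValues`, then stores the operands of the
/// terminating yield into `outputBuffers` at the positions given by
/// `indexing`.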
template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op.region().front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    assert(op.getNumRegions() == 0 && "expected a non-nested region");
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<YieldOp>(terminator) &&
         "expected a yield op at the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}

// Holds the input indices and output indices of a SingleInputPoolingOp `op`,
// as computed by getInputAndOutputIndices below.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
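/// Computes the input and output indices of the SingleInputPoolingOp `op` by
/// applying its input and output indexing maps to the loop induction
/// variables `allIvs`.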
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

namespace {

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///       }
///      }
///    }
/// ```
template <typename IndexedValueType, typename LinalgOpType>
class LinalgScopedEmitter {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       LinalgOpType linalgOp) {
    assert(linalgOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto &b = ScopedContext::getBuilderRef();
    auto loc = ScopedContext::getLocation();
    unsigned nInputs = linalgOp.getNumInputs();
    unsigned nOutputs = linalgOp.getNumOutputs();
    SmallVector<Value, 4> indexedValues;
    indexedValues.reserve(nInputs + nOutputs);

    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
    // region has no uses.
    // 1.a. Emit load from input views.
    for (unsigned i = 0; i < nInputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, linalgOp.getInputIndexingMap(i), allIvs);
      // Passing through IndexedValueType emits the proper load operation.
      indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
    }
    // 1.b. Emit load from output views.
    for (unsigned i = 0; i < nOutputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
      // Passing through IndexedValueType emits the proper load operation.
      indexedValues.push_back(
          IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
    }

    // TODO(ntv): When a region inliner exists, use it.
    // 2. Inline region, currently only works for a single basic block.
    // 3. Emit store.
    SmallVector<SmallVector<Value, 8>, 8> indexing;
    SmallVector<Value, 8> outputBuffers;
    for (unsigned i = 0; i < nOutputs; ++i) {
      indexing.push_back(makeCanonicalAffineApplies(
          b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
      outputBuffers.push_back(linalgOp.getOutputBuffer(i));
    }
    inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues,
                                               indexing, outputBuffers);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
    assert(copyOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto nPar = copyOp.getNumParallelLoops();
    assert(nPar == allIvs.size());
    auto inputIvs =
        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
    auto outputIvs =
        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
    SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
    SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
    IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
    // an n-D loop nest; with or without permutations.
    // clang-format off
    nPar > 0 ? O(oivs) = I(iivs) :
               O() = I();
    // clang-format on
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
    assert(fillOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto nPar = fillOp.getNumParallelLoops();
    assert(nPar == allIvs.size());
    auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
    IndexedValueType O(fillOp.getOutputBuffer(0));
    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
    // an n-D loop nest.
    nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
    assert(dotOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    assert(allIvs.size() == 1);
    Value r_i(allIvs[0]);
    IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
        C(dotOp.getOutputBuffer(0));
    // Emit scalar form.
    C() = C() + A(r_i) * B(r_i);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       MatvecOp matvecOp) {
    assert(matvecOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    assert(allIvs.size() == 2);
    Value i(allIvs[0]), r_j(allIvs[1]);
    IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
        C(matvecOp.getOutputBuffer(0));
    // Emit scalar form.
    C(i) = C(i) + A(i, r_j) * B(r_j);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       MatmulOp matmulOp) {
    assert(matmulOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    assert(allIvs.size() == 3);
    Value i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
    IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
        C(matmulOp.getOutputBuffer(0));
    // Emit scalar form.
    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
  /// Returns the input value of convOp. If the indices in `imIdx` are out of
  /// bounds, returns 0 instead.
  static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                              MutableArrayRef<Value> imIdx) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (!convOp.padding())
      return im(imIdx);

    auto *context = ScopedContext::getContext();
    Value zeroIndex = std_constant_index(0);
    SmallVector<Value, 8> conds;
    SmallVector<Value, 8> clampedImIdx;
    for (auto iter : llvm::enumerate(imIdx)) {
      int idx = iter.index();
      auto dim = iter.value();
      // Only need to iterate over the window dimensions.
      if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) {
        clampedImIdx.push_back(dim);
        continue;
      }

      using edsc::op::operator<;
      using edsc::op::operator>=;
      using edsc::op::operator||;
      Value leftOutOfBound = dim < zeroIndex;
      if (conds.empty())
        conds.push_back(leftOutOfBound);
      else
        conds.push_back(conds.back() || leftOutOfBound);
      Value rightBound = std_dim(convOp.input(), idx);
      conds.push_back(conds.back() || (dim >= rightBound));

      // When padding is involved, the indices will only be shifted to the
      // negative side, so clamping with a max op is enough.
      auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                   {getAffineDimExpr(/*position=*/0, context),
                                    getAffineConstantExpr(0, context)},
                                   context);
      clampedImIdx.push_back(
          affine_max(dim.getType(), maxMap, ValueRange{dim}));
    }

    auto &b = ScopedContext::getBuilderRef();
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value zero = std_constant(type, b.getZeroAttr(type));
    Value readInput = im(clampedImIdx);
    return conds.empty() ? readInput
                         : (Value)std_select(conds.back(), zero, readInput);
  }

  static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
    assert(convOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto &b = ScopedContext::getBuilderRef();
    auto loc = ScopedContext::getLocation();
    auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
    auto maps = llvm::to_vector<8>(llvm::map_range(
        mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
    SmallVector<Value, 8> fIdx(
        makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
    SmallVector<Value, 8> imIdx(
        makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
    SmallVector<Value, 8> oIdx(
        makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

    // Padded conv involves an affine.max in the memory access which is not
    // allowed by affine.load. Override to always use an StdIndexedValue.
    StdIndexedValue I(convOp.input());
    IndexedValueType F(convOp.filter()), O(convOp.output());

    // Emit scalar form.
    Value paddedInput = getConvOpInput(convOp, I, imIdx);
    O(oIdx) += F(fIdx) * paddedInput;
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingMaxOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       PoolingMaxOp op) {
    auto indices = getInputAndOutputIndices(allIvs, op);
    // Emit scalar form.
    Value lhs = std_load(op.output(), indices.outputs);
    Value rhs = std_load(op.input(), indices.inputs);
    using edsc::op::operator>;
    Value maxValue = std_select(lhs > rhs, lhs, rhs);
    std_store(maxValue, op.output(), indices.outputs);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingMinOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       PoolingMinOp op) {
    auto indices = getInputAndOutputIndices(allIvs, op);
    // Emit scalar form.
    Value lhs = std_load(op.output(), indices.outputs);
    Value rhs = std_load(op.input(), indices.inputs);
    using edsc::op::operator<;
    Value minValue = std_select(lhs < rhs, lhs, rhs);
    std_store(minValue, op.output(), indices.outputs);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingSumOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       PoolingSumOp op) {
    auto indices = getInputAndOutputIndices(allIvs, op);
    IndexedValueType input(op.input()), output(op.output());

    // Emit scalar form.
    output(indices.outputs) += input(indices.inputs);
  }
};

/// Emits the MLIR for the scalar part of the indexed generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the induction
///      variables and the scalars from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///       }
///      }
///    }
/// ```
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       IndexedGenericOp indexedGenericOp) {
    assert(indexedGenericOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto &b = ScopedContext::getBuilderRef();
    auto loc = ScopedContext::getLocation();
    unsigned nInputs = indexedGenericOp.getNumInputs();
    unsigned nOutputs = indexedGenericOp.getNumOutputs();
    unsigned nLoops = allIvs.size();
    SmallVector<Value, 4> indexedValues;
    indexedValues.reserve(nLoops + nInputs + nOutputs);
    for (unsigned i = 0; i < nLoops; ++i)
      indexedValues.push_back(allIvs[i]);

    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
    // region has no uses.
    // 1.a. Emit load from input views.
    for (unsigned i = 0; i < nInputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
      // Passing through IndexedValueType emits the proper load operation.
      indexedValues.push_back(
          IndexedValueType(indexedGenericOp.getInput(i))(indexing));
    }
    // 1.b. Emit load from output views.
    for (unsigned i = 0; i < nOutputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
      // Passing through IndexedValueType emits the proper load operation.
      indexedValues.push_back(
          IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
    }

    // TODO(ntv): When a region inliner exists, use it.
    // 2. Inline region, currently only works for a single basic block.
    // 3. Emit store.
    SmallVector<SmallVector<Value, 8>, 8> indexing;
    SmallVector<Value, 8> outputBuffers;
    for (unsigned i = 0; i < nOutputs; ++i) {
      indexing.push_back(makeCanonicalAffineApplies(
          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
      outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
    }
    inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                               indexing, outputBuffers);
  }
};

namespace {
/// Helper struct to generate the loop nest for the op. This is factored out
/// here so that it can be partially specialized for different LoopTy.
template <typename LoopTy, typename ConcreteOpTy>
class GenerateLoopNest {
public:
  using IndexedValueTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineIndexedValue, StdIndexedValue>::type;
  static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
                   MutableArrayRef<Value> allIvs) {
    GenericLoopNestRangeBuilder<LoopTy>(allIvs, loopRanges)([&] {
      SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
      LinalgScopedEmitter<IndexedValueTy,
                          ConcreteOpTy>::emitScalarImplementation(allIvValues,
                                                                  linalgOp);
    });
  }
};

/// Generates the loop nest using scf.parallel. scf.parallel is only used for
/// the outermost parallel loops; all other loops are generated using the
/// scf.for operation.
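///
/// For example (illustrative), with iterator_types = ["parallel", "parallel",
/// "reduction"], the two outermost loops are emitted as a single scf.parallel
/// op and the remaining reduction loop as a nested scf.for.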
template <typename ConcreteOpTy>
class GenerateLoopNest<scf::ParallelOp, ConcreteOpTy> {
public:
  using IndexedValueTy = StdIndexedValue;

  static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
                   MutableArrayRef<Value> allIvs) {
    // Only generate scf.parallel for outer consecutive "parallel"
    // iterator_types.
    // TODO(ravishankarm): Generate scf.parallel for all "parallel" iterator
    // types, not just the outermost ones. Also handle "reduction" iterator
    // types.
    auto nOuterPar = linalgOp.iterator_types()
                         .getValue()
                         .take_while([](Attribute attr) {
                           return attr.cast<StringAttr>().getValue() ==
                                  getParallelIteratorTypeName();
                         })
                         .size();
    // If there are outer parallel loops, emit them as a single scf.parallel op
    // and generate the remaining loops as nested scf.for ops.
    if (nOuterPar) {
      GenericLoopNestRangeBuilder<scf::ParallelOp>(
          allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] {
        GenericLoopNestRangeBuilder<scf::ForOp>(
            allIvs.drop_front(nOuterPar),
            loopRanges.drop_front(nOuterPar))([&] {
          SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
          LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
              emitScalarImplementation(allIvValues, linalgOp);
        });
      });
    } else {
      // If there are no parallel loops then fall back to generating all scf.for
      // operations.
      GenericLoopNestRangeBuilder<scf::ForOp>(allIvs, loopRanges)([&] {
        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
        LinalgScopedEmitter<StdIndexedValue,
                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
                                                                    linalgOp);
      });
    }
  }
};
} // namespace

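/// Emits the loop nest that implements the linalg op `op` with loops of type
/// `LoopTy`, together with its scalar body. Returns the emitted loop
/// operations on success (an empty list for 0-D ops), and llvm::None if the
/// concatenated indexing maps cannot be inverted or the induction variables of
/// the emitted loops cannot be recovered.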
template <typename LoopTy, typename ConcreteOpTy>
Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
  using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
  using IndexedValueTy =
      typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;

  ScopedContext scope(builder, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  auto linalgOp = cast<ConcreteOpTy>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = linalgOp.getNumParallelLoops();
  auto nRed = linalgOp.getNumReductionLoops();
  auto nWin = linalgOp.getNumWindowLoops();
  auto nLoops = nPar + nRed + nWin;
  auto mapsRange =
      linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
  if (!invertedMap)
    return {};
  if (invertedMap.isEmpty()) {
    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
        {}, linalgOp);
    return LinalgLoops();
  }

  SmallVector<Value, 4> allIvs(nLoops);
  auto loopRanges =
      emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
                     getViewSizes(builder, linalgOp));
  assert(loopRanges.size() == allIvs.size());
  Impl::doit(linalgOp, loopRanges, allIvs);
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

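/// Rewrite pattern that lowers a `ConcreteOp` to a loop nest of `LoopType`
/// operations via linalgOpToLoopsImpl and erases the original op on success.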
template <typename LoopType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
  explicit LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Helper classes for type list expansion.
template <typename LoopType, typename... LinalgOps>
class RewritePatternList;

template <typename LoopType>
class RewritePatternList<LoopType> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
};

template <typename LoopType, typename ConcreteOp, typename... LinalgOps>
class RewritePatternList<LoopType, ConcreteOp, LinalgOps...> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
    patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
    RewritePatternList<LoopType, LinalgOps...>::build(patterns, ctx);
  }
};

/// Populate the given list with patterns that convert from Linalg to loops.
template <typename LoopType>
void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
  RewritePatternList<LoopType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::build(patterns, ctx);
}

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
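///
/// For example (illustrative): `affine.apply affine_map<() -> (42)>()` folds
/// to `constant 42 : index`, and `affine.apply affine_map<(d0) -> (d0)>(%i)`
/// folds to `%i`.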
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};
} // namespace

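/// Applies the Linalg-to-loops conversion patterns for `LoopType`, together
/// with local canonicalization and folding patterns, greedily on `op`.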
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) {
  OwningRewritePatternList patterns;
  // Canonicalization and folding patterns applied greedily allow cleaning up
  // the emitted IR on the fly.
  // TODO(ntv) fold view and subview ops?
  FillRewritePatterns<LoopType>(patterns, context);
  DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  applyPatternsAndFoldGreedily(op, patterns);
}

namespace {
struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
  }
};
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
  }
};
struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext());
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy, typename ConcreteOp>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                         Operation *op) {
  return linalgOpToLoopsImpl<LoopTy, ConcreteOp>(op, builder);
}

/// Emits a loop nest of `scf.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ForOp, ConcreteOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
                                                  Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<AffineForOp, ConcreteOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
                                                    Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ParallelOp, ConcreteOp>(builder, op);
  return loops ? success() : failure();
}

// TODO: Make these instantiations more future-proof to avoid the need to
// update them as soon as we add new ops.
#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE)                                \
  template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>(               \
      OpBuilder & builder, Operation * op);                                    \
  template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>(         \
      OpBuilder & builder, Operation * op);                                    \
  template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>(       \
      OpBuilder & builder, Operation * op);                                    \
  template Optional<LinalgLoops>                                               \
      mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp, OP_TYPE>(            \
          OpBuilder & builder, Operation * op);

INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)