//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

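/// Creates one `affine.apply` operation per result of `map`, applied to
/// `vals`. Each result expression is wrapped into its own single-result map,
/// which is canonicalized together with its operands before emission.
/// Returns the resulting index values; an empty `map` yields an empty vector.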
static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

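/// Permutes the induction variables `ivs` according to the optional
/// `permutation` map; when no permutation is provided, returns `ivs`
/// unchanged.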
static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
                                        Optional<AffineMap> permutation) {
  return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
                                        ScopedContext::getLocation(),
                                        permutation.getValue(), ivs)
                     : SmallVector<Value, 4>(ivs.begin(), ivs.end());
}

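/// Inlines the single-block region of `op` at the current insertion point,
/// mapping its block arguments to `indexedValues`, and then emits one store
/// per operand of the terminating linalg.yield: the yielded value is written
/// into the corresponding `outputBuffers` entry at the indices in `indexing`.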
template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  assert(op->getNumRegions() == 1 && "Expected single region op");
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<linalg::YieldOp>(terminator) &&
         "expected a yield op in the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}

/// Bundles the input and output indices of a SingleInputPoolingOp `op`, as
/// computed by `getInputAndOutputIndices`.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
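/// Returns the input and output indices of the SingleInputPoolingOp `op`,
/// obtained by applying the first and third of its indexing maps (used for
/// the input and output accesses, respectively) to the enclosing induction
/// variables `allIvs`.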
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Inlining the region of the op, whose block arguments are bound to the
///      scalars from point 1. above (shown as a call to `@foo` in the example
///      below).
///   3. Emitting store ops to write the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///       }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
                                             outputBuffers);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
  assert(copyOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = copyOp.getNumParallelLoops();
  assert(nPar == allIvs.size());
  auto inputIvs =
      permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
  auto outputIvs =
      permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
  SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
  SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
  IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
  // an n-D loop nest; with or without permutations.
  // clang-format off
    nPar > 0 ? O(oivs) = I(iivs) :
               O() = I();
  // clang-format on
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
  assert(fillOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = fillOp.getNumParallelLoops();
  assert(nPar == allIvs.size());
  auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
  IndexedValueType O(fillOp.getOutputBuffer(0));
  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
  // an n-D loop nest; with or without permutations.
  nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
}

/// Creates a padded view into the given `input` tensor, using `indices` to
/// access the tensor. `skipPadding` lists the dimensions for which no padding
/// is needed, e.g. the non-spatial dimensions of a convolution.
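///
/// For each dimension that is not in `skipPadding`, the emitted IR computes
/// an out-of-bounds condition (index < 0 or index >= dimension size), clamps
/// the index to be non-negative with an affine.max, and selects `padValue`
/// over the loaded value whenever the condition holds.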
template <typename IndexedValueType>
Value getPaddedInput(Value input, ArrayRef<Value> indices,
                     ArrayRef<int> skipPadding, Value padValue) {
  // TODO: add a level of indirection to linalg.generic.

  IndexedValueType indexedInput(input);

  auto *context = ScopedContext::getContext();
  Value zeroIndex = std_constant_index(0);
  SmallVector<Value, 8> conds;
  SmallVector<Value, 8> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    using edsc::op::sge;
    using edsc::op::slt;
    using edsc::op::operator||;
    Value leftOutOfBound = slt(dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(conds.back() || leftOutOfBound);
    Value rightBound = std_dim(input, idx);
    conds.push_back(conds.back() || (sge(dim, rightBound)));

    // With padding, indices are only ever shifted towards negative values, so
    // clamping with a max against zero is sufficient.
    auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                 {getAffineDimExpr(/*position=*/0, context),
                                  getAffineConstantExpr(0, context)},
                                 context);
    clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
  }

  Value readInput = indexedInput(clampedImIdx);
  return conds.empty() ? readInput
                       : (Value)std_select(conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value is 0 for ConvOp and PoolingSumOp, -inf (resp. the
/// smallest signed integer) for PoolingMaxOp, and +inf (resp. the largest
/// signed integer) for PoolingMinOp.
template <typename OpType>
Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(
        floatType,
        APFloat::getInf(floatType.getFloatSemantics(), /*Negative*/ true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMaxOp relies on a signed
    // comparison, so use a signed constant irrespective of the signedness of
    // the integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(floatType,
                          APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMinOp relies on a signed
    // comparison, so use a signed constant irrespective of the signedness of
    // the integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

template <>
Attribute getPadValueAttr<ConvOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has non-zero low or high padding on any spatial
/// dimension.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

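/// Emits the scalar form of a ConvOp at the point given by `allIvs`: the
/// filter and input values are loaded, multiplied, and accumulated into the
/// output. In the absence of padding, the emitted IR is roughly of the form
/// below (illustrative only: SSA names are invented and an f32 element type
/// is assumed):
///
/// ```
///   %f    = load %filter[...]
///   %in   = load %input[...]
///   %prod = mulf %f, %in : f32
///   %cur  = load %output[...]
///   %new  = addf %cur, %prod : f32
///   store %new, %output[...]
/// ```
///
/// When the op has non-zero padding, the input is read through
/// `getPaddedInput` (with a StdIndexedValue) so that out-of-bound accesses
/// yield the padding value instead.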
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  SmallVector<Value, 8> fIdx(
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value, 8> imIdx(
      makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value, 8> oIdx(
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  IndexedValueType F(convOp.filter()), O(convOp.output());

  // Emit scalar form. A padded conv involves an affine.max in the memory
  // access, which affine.load does not allow, so fall back to a
  // StdIndexedValue when there is non-zero padding.
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type));
    Value paddedInput = getPaddedInput<StdIndexedValue>(
        convOp.input(), imIdx,
        /* Only need to pad the window dimensions */
        {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    O(oIdx) += F(fIdx) * paddedInput;
  } else {
    IndexedValueType I(convOp.input());
    O(oIdx) += F(fIdx) * I(imIdx);
  }
}

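/// Returns true if `poolingOp` has non-zero low or high padding on any of its
/// window loops.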
template <typename PoolingOp>
static bool hasPadding(PoolingOp poolingOp) {
  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

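/// Returns the pooling input value at `inputIndices`. When `op` has non-zero
/// padding, the access goes through `getPaddedInput` with the op-specific pad
/// value so that out-of-bound reads return the padding identity.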
template <typename IndexedValueType, typename PoolingOp>
static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) {
  if (hasPadding(op)) {
    Type type =
        op.input().getType().template cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type));
    return getPaddedInput<StdIndexedValue>(op.input(), inputIndices,
                                           /*Pad every dimension*/ {},
                                           padValue);
  }
  IndexedValueType input(op.input());
  return input(inputIndices);
}

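/// Emits the scalar part of a pooling min/max op: loads the current output
/// value and the (possibly padded) input value, and stores back the smaller
/// (PoolingMinOp) or larger (PoolingMaxOp) of the two, using a signed
/// comparison and a select.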
template <typename IndexedValueType, typename OpType>
void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs, OpType op) {
  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
  // Emit scalar form.
  IndexedValueType output(op.output());
  Value lhs = output(indices.outputs);
  Value rhs = getPoolingInput<IndexedValueType>(op, indices.inputs);
  using edsc::op::sgt;
  using edsc::op::slt;
  Value value = std::is_same<OpType, PoolingMinOp>()
                    ? std_select(slt(lhs, rhs), lhs, rhs)
                    : std_select(sgt(lhs, rhs), lhs, rhs);
  output(indices.outputs) = value;
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
  auto indices = getInputAndOutputIndices(allIvs, op);
  IndexedValueType output(op.output());

  // Emit scalar form.
  output(indices.outputs) +=
      getPoolingInput<IndexedValueType>(op, indices.inputs);
}

/// Emits the MLIR for the scalar part of the indexed generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Inlining the region of the op, whose block arguments are bound to the
///      induction variables followed by the scalars from point 1. above
///      (shown as a call to `@foo` in the example below).
///   3. Emitting store ops to write the results of 2. to the output views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///       }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     IndexedGenericOp indexedGenericOp) {
  assert(indexedGenericOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = indexedGenericOp.getNumInputs();
  unsigned nOutputs = indexedGenericOp.getNumOutputs();
  unsigned nLoops = allIvs.size();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nLoops + nInputs + nOutputs);
  for (unsigned i = 0; i < nLoops; ++i)
    indexedValues.push_back(allIvs[i]);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
    // Passing input i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
    // Passing output i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                             indexing, outputBuffers);
}

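/// Generates a loop nest of type `LoopTy` over the loop ranges of the linalg
/// op `op` and emits its scalar implementation in the innermost loop body.
/// Returns the generated loop operations (deduplicated, since parallel loops
/// carry several induction variables), or llvm::None if any induction
/// variable is missing or cannot be traced back to an enclosing loop op.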
template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
                                                 OpBuilder &builder) {
  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;

  ScopedContext scope(builder, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  auto linalgOp = cast<LinalgOp>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
  SmallVector<Value, 4> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),
      [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(op)
            .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
                  PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
              emitScalarImplementation<IndexedValueTy>(allIvs, op);
            })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
  // The number of loop ops may differ from the number of ivs since loops such
  // as affine.parallel and scf.parallel carry multiple ivs.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

namespace {
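/// Rewrites any LinalgOp into a nest of `LoopType` loops by emitting its
/// scalar implementation, then erases the original op.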
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!isa<LinalgOp>(op))
      return failure();
    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

struct FoldAffineOp;
} // namespace

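/// Populates `LinalgRewritePattern<LoopType>` together with DimOp and
/// AffineApplyOp canonicalizations and the local FoldAffineOp pattern, and
/// applies them greedily to `funcOp`.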
template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
  OwningRewritePatternList patterns;
  patterns.insert<LinalgRewritePattern<LoopType>>();
  DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

namespace {
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
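///
/// For example (illustrative only):
///   affine.apply affine_map<() -> (42)>()      // -> constant 42 : index
///   affine.apply affine_map<(d0) -> (d0)>(%i)  // -> %i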
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext());
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest of type `LoopTy` with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                         Operation *op) {
  return linalgOpToLoopsImpl<LoopTy>(op, builder);
}

template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
                                                Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
                                               Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
                                                    Operation *op);

/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
                                                  Operation *op) {
  Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
  Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
                                                    Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
  return loops ? success() : failure();
}