1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 //===---------------------------------------------------------------------===//
32 // Methods and patterns that fuse elementwise `linalg.generic` operations.
33 //===---------------------------------------------------------------------===//
34 
35 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
36 /// the `producer` to use in the fused operation given the indexing map of the
37 /// result of the producer in the consumer.
38 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
39     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
40     AffineMap fusedConsumerArgIndexMap) {
41   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
42   // from consumer loop -> consumer arg tensor index/producer result tensor
43   // index. The fused loop is same as the consumer loop. For each producer arg
44   // the indexing map to be computed is a map from consumer loop -> producer
45   // arg tensor index.
46   // producerResultIndexMap is a map from producer loop -> tensor index.
47   // Compute the inverse to get map from tensor index -> producer loop.
48   // The inverse is a map from producer result tensor index -> producer loop.
49   AffineMap invProducerResultIndexMap =
50       inversePermutation(producerResultIndexMap);
51   assert(invProducerResultIndexMap &&
52          "expected producer result indexing map to be invertible");
53 
54   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
55   // argMap is a map from producer loop -> producer arg tensor index.
56   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
57 
58   // Compose argMap with invProducerResultIndexMap to get a map from
59   // producer result tensor index -> producer arg tensor index.
60   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
61 
62   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
63   // consumer loop/ fused loop -> producer arg tensor index.
64   return t1.compose(fusedConsumerArgIndexMap);
65 }
66 
67 /// Conditions for elementwise fusion of generic operations.
68 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
69                                      OpOperand *consumerOpOperand) {
70   // Producer and consumer must have tensor semantics.
71   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
72     return false;
73 
74   // Verify that
75   // - the producer has all "parallel" iterator type.
76   if (producer.getNumParallelLoops() != producer.getNumLoops())
77     return false;
78 
79   // Only allow fusing the producer of an input operand for now.
80   // TODO: allow fusing the producer of an output operand.
81   if (!consumer.isInputTensor(consumerOpOperand))
82     return false;
83 
84   // Get the consumer index map. The number of results of the consumer index
85   // map must match the number of loops of the producer.
86   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
88     return false;
89 
90   // Currently support only operations with single result.
91   if (producer.getNumOutputs() != 1)
92     return false;
93 
94   // Finally the index_map for the result must be invertible. For now just
95   // verify it is a permutation.
96   AffineMap producerResultIndexMap =
97       producer.getTiedIndexingMap(producer.getOutputOperand(0));
98   if (!producerResultIndexMap.isPermutation())
99     return false;
100 
101   // Ensure that the fusion does not remove size information required to
102   // get the loop bounds. For non-reduction generics, this is trivially the
103   // case due to the output operand. For reductions, we need to check that after
104   // the fusion, each loop dimension has at least one input that defines it.
105   if ((consumer.getNumReductionLoops())) {
106     BitVector coveredDims(consumer.getNumLoops(), false);
107 
108     auto addToCoveredDims = [&](AffineMap map) {
109       for (auto result : map.getResults())
110         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
111           coveredDims[dimExpr.getPosition()] = true;
112     };
113 
114     for (auto pair :
115          llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
116       Value operand = std::get<0>(pair);
117       if (operand == consumerOpOperand->get())
118         continue;
119       AffineMap operandMap = std::get<1>(pair);
120       addToCoveredDims(operandMap);
121     }
122 
123     for (OpOperand *operand : producer.getInputOperands()) {
124       AffineMap newIndexingMap =
125           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
126               operand, producerResultIndexMap, consumerIndexMap);
127       addToCoveredDims(newIndexingMap);
128     }
129     if (!coveredDims.all())
130       return false;
131   }
132 
133   return true;
134 }
135 
136 /// Generate the region of the fused tensor operation. The region of the fused
137 /// op must be empty.
138 static void
139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
140                                  AffineMap consumerToProducerLoopsMap,
141                                  OpOperand *consumerOpOperand,
142                                  unsigned nloops) {
143   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
144   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
145   // Build the region of the fused op.
146   Block &producerBlock = producer->getRegion(0).front();
147   Block &consumerBlock = consumer->getRegion(0).front();
148   Block *fusedBlock = new Block();
149   fusedOp.region().push_back(fusedBlock);
150   BlockAndValueMapping mapper;
151   OpBuilder::InsertionGuard guard(rewriter);
152   rewriter.setInsertionPointToStart(fusedBlock);
153 
154   // 2. Add an index operation for every fused loop dimension and use the
155   // `consumerToProducerLoopsMap` to map the producer indices.
156   if (producer.hasIndexSemantics()) {
157     // Add an index operation for every fused loop dimension.
158     unsigned numFusedOpLoops =
159         std::max(producer.getNumLoops(), consumer.getNumLoops());
160     SmallVector<Value> fusedIndices;
161     fusedIndices.reserve(numFusedOpLoops);
162     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
164                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
165                     });
166     for (IndexOp indexOp :
167          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
169           producer.getLoc(),
170           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
171       mapper.map(indexOp.getResult(), newIndex);
172     }
173   }
174   // TODO: allow fusing the producer of an output operand.
175   assert(consumer.isInputTensor(consumerOpOperand) &&
176          "expected producer of input operand");
177   // 3. Consumer input operands up to consumerIdx (exclusive).
178   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
179            consumerOpOperand->getOperandNumber())) // input assumption.
180     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
181 
182   // Replacing consumerIdx requires getting the cloned, yielded, value from
183   // the (cloned) producer block. This happens in step 9.
184 
185   // 4. Splice in producer's input operands.
186   for (BlockArgument bbArg :
187        producerBlock.getArguments().take_front(producer.getNumInputs()))
188     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
189 
190   // 4.b. Producer output operand/map that is fused needs to be mapped to the
191   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
192   assert(producer->getNumResults() == 1 && "expected single result producer");
193   if (producer.isInitTensor(producer.getOutputOperand(0))) {
194     BlockArgument bbArg = producerBlock.getArguments()
195                               .drop_front(producer.getNumInputs())
196                               // TODO: bbArg index of
197                               .front();
198     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
199   }
200   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
201   for (BlockArgument bbArg :
202        consumerBlock.getArguments()
203            .take_front(consumer.getNumInputs())
204            .drop_front(consumerOpOperand->getOperandNumber() + 1))
205     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
206   // 6. All of consumer's output operands.
207   for (BlockArgument bbArg :
208        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
209     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
210   // 7. All of producer's output operands except the one fused.
211   // TODO: allow fusion of multi-result producers.
212   assert(producer->getNumResults() == 1 && "expected single result producer");
213 
214   // 8. Clone all producer operations except for the yield and index operations
215   // to the fused operation.
216   for (auto &op : producerBlock.without_terminator()) {
217     if (!isa<IndexOp>(op))
218       rewriter.clone(op, mapper);
219   }
220   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
221   // forward the yield operand.
222   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
223   // TODO: allow fusion of multi-result producers.
224   assert(producer->getNumResults() == 1 && "expected single result producer");
225   unsigned producerResultNumber = 0;
226   Value replacement =
227       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
228   // Sanity checks, if replacement is not already in the mapper then it must be
229   // produced outside.
230   if (replacement == yieldOp.getOperand(producerResultNumber)) {
231     if (auto bb = replacement.dyn_cast<BlockArgument>())
232       assert(bb.getOwner() != &producerBlock &&
233              "yielded block argument must have been mapped");
234     else
235       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
236              "yielded value must have been mapped");
237   }
238   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
239              replacement);
240   // 10. Clone operations from the consumer to the fused op.
241   for (auto &op : consumerBlock.getOperations())
242     rewriter.clone(op, mapper);
243 
244   // Sanity checks.
245   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
246          "Ill-formed GenericOp region");
247 }
248 
249 static Optional<SmallVector<Value>>
250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
251                        const ControlFusionFn &controlFn,
252                        PatternRewriter &rewriter) {
253   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
254   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
255       !controlFn(producer->getResult(0), *consumerOpOperand))
256     return llvm::None;
257 
258   // TODO: allow fusing the producer of an output operand.
259   assert(consumer.isInputTensor(consumerOpOperand) &&
260          "expected producer of input operand");
261 
262   // Compute the fused operands list and indexing maps.
263   SmallVector<Value> fusedOperands;
264   SmallVector<AffineMap> fusedIndexMaps;
265   fusedOperands.reserve(producer->getNumOperands() +
266                         consumer->getNumOperands());
267   fusedIndexMaps.reserve(producer->getNumOperands() +
268                          consumer->getNumOperands());
269   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
270   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
271   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
272   SmallVector<OpOperand *>::iterator it =
273       llvm::find(consumerInputs, consumerOpOperand);
274   assert(it != consumerInputs.end() && "expected to find the consumer operand");
275   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276     fusedOperands.push_back(opOperand->get());
277     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
278   }
279   // 4. Splice in producer's input operands/maps.
280   assert(producer->getNumResults() == 1 && "expected single result producer");
281   AffineMap producerResultIndexMap =
282       producer.getTiedIndexingMap(producer.getOutputOperand(0));
283   for (OpOperand *opOperand : producer.getInputOperands()) {
284     fusedOperands.push_back(opOperand->get());
285     // Compute indexing maps for the producer args in the fused operation.
286     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
287         opOperand, producerResultIndexMap,
288         consumer.getTiedIndexingMap(consumerOpOperand));
289     fusedIndexMaps.push_back(map);
290   }
291   // 4.b. Producer output operand/map that is fused needs to be passed if it is
292   // an "initTensor" (i.e. its value is actually read).
293   assert(producer->getNumResults() == 1 && "expected single result producer");
294   if (producer.isInitTensor(producer.getOutputOperand(0))) {
295     fusedOperands.push_back(producer.getOutputOperand(0)->get());
296     // Compute indexing maps for the producer args in the fused operation.
297     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
298         producer.getOutputOperand(0), producerResultIndexMap,
299         consumer.getTiedIndexingMap(consumerOpOperand));
300     fusedIndexMaps.push_back(map);
301   }
302   // 5. Remaining consumer's input operands/maps (drop past index
303   // `consumerIdx`).
304   for (OpOperand *opOperand :
305        llvm::make_range(std::next(it), consumerInputs.end())) {
306     fusedOperands.push_back(opOperand->get());
307     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
308   }
309   // 6. All of consumer's output operands (skip operands: added by the builder).
310   for (OpOperand *opOperand : consumer.getOutputOperands())
311     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
312   // 7. All of producer's output operands/maps except the one fused.
313   // TODO: allow fusion of multi-result producers.
314   assert(producer->getNumResults() == 1 && "expected single result producer");
315 
316   // Generate the fused op.
317   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
318   auto fusedOp = rewriter.create<GenericOp>(
319       consumer.getLoc(), consumer->getResultTypes(),
320       /*inputs=*/fusedOperands,
321       // TODO: handle outputs.
322       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
323       consumer.iterator_types(),
324       /*doc=*/nullptr,
325       /*library_call=*/nullptr);
326   if (!fusedOp.getShapesToLoopsMap()) {
327     // Fused op has invalid indexing maps. Typically this means something is off
328     // in the input, but going ahead here would result in verification errors.
329     // So cleanup and abort.
330     rewriter.eraseOp(fusedOp);
331     return llvm::None;
332   }
333 
334   // Construct an AffineMap from consumer loops to producer loops.
335   // consumer loop -> tensor index
336   AffineMap consumerResultIndexMap =
337       consumer.getTiedIndexingMap(consumerOpOperand);
338   // tensor index -> producer loop
339   AffineMap invProducerResultIndexMap =
340       inversePermutation(producerResultIndexMap);
341   assert(invProducerResultIndexMap &&
342          "expected producer result indexig map to be invertible");
343   // consumer loop -> producer loop
344   AffineMap consumerToProducerLoopsMap =
345       invProducerResultIndexMap.compose(consumerResultIndexMap);
346 
347   generateFusedElementwiseOpRegion(rewriter, fusedOp,
348                                    consumerToProducerLoopsMap,
349                                    consumerOpOperand, consumer.getNumLoops());
350   return SmallVector<Value>(fusedOp->getResults());
351 }
352 
353 static Optional<SmallVector<Value>>
354 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
355                    GenericOp producer, const ControlFusionFn &controlFn) {
356   if (producer->getNumResults() != 1)
357     return llvm::None;
358 
359   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360                                 rewriter);
361 }
362 
363 namespace {
364 /// Patterns to fuse a generic op, with the producer of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
367   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
368                      PatternBenefit benefit = 1)
369       : OpRewritePattern<GenericOp>(context, benefit),
370         controlFn(std::move(fun)) {}
371 
372   LogicalResult matchAndRewrite(GenericOp genericOp,
373                                 PatternRewriter &rewriter) const override {
374     // Find the first operand that is defined by another generic op on tensors.
375     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
376       auto producer =
377           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378       if (!producer || !producer.hasTensorSemantics())
379         continue;
380       Optional<SmallVector<Value>> fusedOpResults =
381           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
382       if (fusedOpResults) {
383         rewriter.replaceOp(genericOp, *fusedOpResults);
384         return success();
385       }
386     }
387     return failure();
388   }
389 
390 private:
391   ControlFusionFn controlFn;
392 };
393 } // namespace
394 
395 //===---------------------------------------------------------------------===//
396 // Methods and patterns that fuse reshape ops with elementwise operations by
397 // expanding the dimensionality of the elementwise operations.
398 //===---------------------------------------------------------------------===//
399 
400 /// Conditions for folding a generic operation with a reshape op by expanding
401 /// the iteration space dimensionality for tensor operations. These are
402 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
403 /// following fusion pattern.
404 ///
405 ///  Consider
406 ///
407 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
408 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
409 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
410 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
411 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
412 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
413 ///
414 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
415 ///  is increased to match the result (operand) of the tensor_expand_shape.
416 ///  The indexing_map of the fused tensor in the `genericOp` and the
417 ///  reassociation map helps compute the indexing maps of the modified op.
418 ///  For the above example, based on the reassociation map it
419 ///  can be concluded that
420 ///
421 ///  - The loop used to access the first dimension of the fused tensor is split
422 ///    into two.
423 ///  - The loop used to access the second dimension of the fused tensor is kept
424 ///    as is.
425 ///  - The loop used to access the third dimension of the fused tensor is split
426 ///    into three.
427 ///
428 ///  i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
429 ///  op, then
430 ///
431 ///   d0 -> e0, e1
432 ///   d1 -> e2, e3, e4
433 ///   d2 -> e5
434 ///
435 ///  substituting this, the generic op can be rewritten as
436 ///
437 ///  %d = linalg.generic ins(%0, %1 : )
438 ///        indexing_maps =
439 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
440 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
441 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
442 ///
443 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
444 ///  to make it consistent
445 ///
446 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
447 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
448 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
449 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
450 ///
451 ///  The added reshapes are again expanding patterns, so they will get fused
452 ///  with its producers if possible.
453 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
454                                                OpOperand *fusableOpOperand) {
455   // Is fusable only if:
456   // - All the indexing maps for operands and results are projected
457   //   permutations.
458   // - The fused tensor is not a scalar.
459   // - All the loops are parallel loops.
460   return genericOp.hasTensorSemantics() &&
461          llvm::all_of(genericOp.indexing_maps().getValue(),
462                       [](Attribute attr) {
463                         return attr.cast<AffineMapAttr>()
464                             .getValue()
465                             .isProjectedPermutation();
466                       }) &&
467          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
468          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
469            return attr.cast<StringAttr>().getValue() ==
470                   getParallelIteratorTypeName();
471          });
472 }
473 
474 namespace {
475 /// Information needed to expand a generic operation to fold the reshape with
476 /// it.
477 class ExpansionInfo {
478 public:
479   // Computes the mapping from original dimensions of the op to the dimensions
480   // of the expanded op given the `indexingMap` of the fused operand/result of
481   // the generic op, the `reassocationMaps` of the reshape op and the shape of
482   // the expanded op.
483   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
484                         ArrayRef<AffineMap> reassociationMaps,
485                         ArrayRef<int64_t> expandedShape,
486                         ArrayRef<int64_t> collapsedShape,
487                         PatternRewriter &rewriter);
488   unsigned getOrigOpNumDims() const { return reassociation.size(); }
489   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
490   ReassociationIndicesRef getExpandedDims(unsigned i) const {
491     return reassociation[i];
492   }
493   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
494     return expandedShapeMap[i];
495   }
496   ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
497 
498 private:
499   /// Reassociation from the dimensions in the original operation to the
500   /// dimension of the expanded operation.
501   SmallVector<ReassociationIndices> reassociation;
502   /// Mapping from extent of loops in the original operation, to the extent of
503   /// loops in the expanded operation.
504   SmallVector<SmallVector<int64_t>> expandedShapeMap;
505   /// Extent of the loop in the original operation.
506   SmallVector<int64_t> originalLoopExtent;
507   unsigned expandedOpNumDims;
508 };
509 } // namespace
510 
511 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
512                                      OpOperand *fusableOpOperand,
513                                      ArrayRef<AffineMap> reassociationMaps,
514                                      ArrayRef<int64_t> expandedShape,
515                                      ArrayRef<int64_t> collapsedShape,
516                                      PatternRewriter &rewriter) {
517   if (reassociationMaps.empty())
518     return failure();
519   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
520 
521   Optional<SmallVector<int64_t, 4>> originalLoopRange =
522       linalgOp.getStaticLoopRanges();
523   if (!originalLoopRange)
524     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
525   originalLoopExtent.assign(originalLoopRange->begin(),
526                             originalLoopRange->end());
527 
528   reassociation.clear();
529   expandedShapeMap.clear();
530   // Compute the number of dimension in the expanded op that correspond to each
531   // dimension of the original op.
532   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
533   expandedShapeMap.resize(fusedIndexMap.getNumDims());
534   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
535     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
536     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
537     numExpandedDims[pos] = foldedDims.getNumResults();
538     ArrayRef<int64_t> shape =
539         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
540     expandedShapeMap[pos].assign(shape.begin(), shape.end());
541   }
542   // The remaining dimensions remain the same.
543   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
544     if (expandedShapeMap[i].empty())
545       expandedShapeMap[i] = {originalLoopExtent[i]};
546 
547   // Compute reassociation map from the original op to the expanded op.
548   unsigned sum = 0;
549   reassociation.reserve(fusedIndexMap.getNumDims());
550   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
551     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
552     reassociation.emplace_back(seq.begin(), seq.end());
553     sum += numFoldedDim.value();
554   }
555   expandedOpNumDims = sum;
556   return success();
557 }
558 
559 /// Epanding the body of a linalg operation requires adaptations of the accessed
560 /// loop indices. Specifically, access of indices in the original operation need
561 /// to be replaced with linearizations of indices in the expanded op. That
562 /// requires the shape of the expanded dimensions to be static (at least all but
563 /// the most significant). For now check that these are all statically sized.
564 /// Note that this could be extended to handle dynamic case, but the
565 /// implementation below uses `affine.apply` which seems to have issues when the
566 /// shapes are not static.
567 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
568                                            const ExpansionInfo &expansionInfo,
569                                            PatternRewriter &rewriter) {
570   if (!genericOp.hasIndexSemantics())
571     return success();
572   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
573     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
574     if (expandedShape.size() == 1)
575       continue;
576     for (int64_t shape : expandedShape.drop_front()) {
577       if (ShapedType::isDynamic(shape)) {
578         return rewriter.notifyMatchFailure(
579             genericOp, "cannot expand due to index semantics and dynamic dims");
580       }
581     }
582   }
583   return success();
584 }
585 
586 /// Return the indexing map to use in the expanded op for a given the
587 /// `indexingMap` of the original operation.
588 static AffineMap
589 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
590                            const ExpansionInfo &expansionInfo) {
591   SmallVector<AffineExpr> newExprs;
592   for (AffineExpr expr : indexingMap.getResults()) {
593     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
594     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
595         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
596           return builder.getAffineDimExpr(static_cast<unsigned>(v));
597         }));
598     newExprs.append(expandedExprs.begin(), expandedExprs.end());
599   }
600   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
601                         indexingMap.getNumSymbols(), newExprs,
602                         builder.getContext());
603 }
604 
605 /// Return the type of the operand/result to use in the expanded op given the
606 /// type in the original op.
607 static RankedTensorType getExpandedType(RankedTensorType originalType,
608                                         AffineMap indexingMap,
609                                         const ExpansionInfo &expansionInfo) {
610   SmallVector<int64_t> expandedShape;
611   for (AffineExpr expr : indexingMap.getResults()) {
612     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
613     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
614     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
615   }
616   return RankedTensorType::get(expandedShape, originalType.getElementType());
617 }
618 
619 /// Returns the reassociation maps to use in the `tensor.expand_shape`
620 /// operation to convert the operands of the original operation to operands of
621 /// the expanded operation. The same method is used to compute the
622 /// `tensor.collapse_shape` used to collapse the result of the expanded
623 /// op to get the value that can replace all uses of the results of the original
624 /// op.
625 static SmallVector<ReassociationIndices>
626 getReassociationForExpansion(AffineMap indexingMap,
627                              const ExpansionInfo &expansionInfo) {
628   SmallVector<ReassociationIndices> reassociation;
629   unsigned numReshapeDims = 0;
630   for (AffineExpr expr : indexingMap.getResults()) {
631     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
632     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
633     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
634         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
635     reassociation.emplace_back(std::move(indices));
636     numReshapeDims += numExpandedDims;
637   }
638   return reassociation;
639 }
640 
641 /// Update the body of an expanded linalg operation having index semantics. The
642 /// indices of the original operation need to be recovered by linearizing the
643 /// indices of the correspoding dimensions of the expanded operation. For now it
644 /// is assumed that the shapes of the expanded operation needed for
645 /// linearization are static.
646 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
647                                           Location loc, Region &fusedRegion,
648                                           const ExpansionInfo &expansionInfo) {
649   // Replace the original indices by the linearization of the expanded indices.
650   for (IndexOp indexOp :
651        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
652     ArrayRef<int64_t> expandedDims =
653         expansionInfo.getExpandedDims(indexOp.dim());
654     assert(!expandedDims.empty() && "expected valid expansion info");
655 
656     // Skip index operations that are not affected by the expansion.
657     if (expandedDims.size() == 1 &&
658         expandedDims.front() == (int64_t)indexOp.dim())
659       continue;
660 
661     // Linearize the expanded indices of the original index dimension.
662     OpBuilder::InsertionGuard guard(rewriter);
663     rewriter.setInsertionPointAfter(indexOp);
664     ArrayRef<int64_t> expandedDimsShape =
665         expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
666     SmallVector<Value> expandedIndices;
667     expandedIndices.reserve(expandedDims.size() - 1);
668     llvm::transform(
669         expandedDims.drop_front(), std::back_inserter(expandedIndices),
670         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
671     Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
672     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
673       assert(!ShapedType::isDynamic(std::get<0>(it)));
674       AffineExpr idx, acc;
675       bindDims(rewriter.getContext(), idx, acc);
676       newIndex = rewriter.create<AffineApplyOp>(
677           indexOp.getLoc(), idx + acc * std::get<0>(it),
678           ValueRange{std::get<1>(it), newIndex});
679     }
680     rewriter.replaceOp(indexOp, newIndex);
681   }
682 }
683 
684 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
685 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
686 /// that those conditions have been satisfied.
687 static Optional<SmallVector<Value>>
688 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
689                            OpOperand *fusableOpOperand,
690                            PatternRewriter &rewriter) {
691   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
692          "preconditions for fuse operation failed");
693   // Check if reshape is expanding or collapsing.
694   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
695   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
696   bool isExpanding = (expandingReshapeOp != nullptr);
697   RankedTensorType expandedType = isExpanding
698                                       ? expandingReshapeOp.getResultType()
699                                       : collapsingReshapeOp.getSrcType();
700   RankedTensorType collapsedType = isExpanding
701                                        ? expandingReshapeOp.getSrcType()
702                                        : collapsingReshapeOp.getResultType();
703 
704   ExpansionInfo expansionInfo;
705   if (failed(expansionInfo.compute(
706           genericOp, fusableOpOperand,
707           isExpanding ? expandingReshapeOp.getReassociationMaps()
708                       : collapsingReshapeOp.getReassociationMaps(),
709           expandedType.getShape(), collapsedType.getShape(), rewriter)))
710     return llvm::None;
711 
712   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
713     return llvm::None;
714 
715   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
716       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
717         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
718       }));
719 
720   SmallVector<Value> expandedOpOperands;
721   expandedOpOperands.reserve(genericOp.getNumInputs());
722   for (OpOperand *opOperand : genericOp.getInputOperands()) {
723     if (opOperand == fusableOpOperand) {
724       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
725                                                : collapsingReshapeOp.src());
726       continue;
727     }
728     if (genericOp.isInputTensor(opOperand)) {
729       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
730       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
731       RankedTensorType expandedOperandType =
732           getExpandedType(opOperandType, indexingMap, expansionInfo);
733       if (expandedOperandType != opOperand->get().getType()) {
734         // Reshape the operand to get the right type.
735         SmallVector<ReassociationIndices> reassociation =
736             getReassociationForExpansion(indexingMap, expansionInfo);
737         if (failed(reshapeLikeShapesAreCompatible(
738                 [&](const Twine &msg) {
739                   return rewriter.notifyMatchFailure(genericOp, msg);
740                 },
741                 opOperandType.getShape(), expandedOperandType.getShape(),
742                 reassociation,
743                 /*isExpandingReshape=*/true)))
744           return llvm::None;
745         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
746             genericOp.getLoc(), expandedOperandType, opOperand->get(),
747             reassociation));
748         continue;
749       }
750     }
751     expandedOpOperands.push_back(opOperand->get());
752   }
753 
754   Location loc = genericOp.getLoc();
755   SmallVector<Value> outputs;
756   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
757     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
758     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
759     RankedTensorType expandedOutputType =
760         getExpandedType(opOperandType, indexingMap, expansionInfo);
761     if (expandedOutputType != opOperand->get().getType()) {
762       SmallVector<ReassociationIndices> reassociation =
763           getReassociationForExpansion(indexingMap, expansionInfo);
764       if (failed(reshapeLikeShapesAreCompatible(
765               [&](const Twine &msg) {
766                 return rewriter.notifyMatchFailure(genericOp, msg);
767               },
768               opOperandType.getShape(), expandedOutputType.getShape(),
769               reassociation,
770               /*isExpandingReshape=*/true)))
771         return llvm::None;
772       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
773           genericOp.getLoc(), expandedOutputType, opOperand->get(),
774           reassociation));
775     }
776   }
777 
778   // The iterator types of the expanded op are all parallel.
779   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
780                                        getParallelIteratorTypeName());
781 
782   TypeRange resultTypes = ValueRange(outputs).getTypes();
783   auto fusedOp =
784       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
785                                  /*inputs=*/expandedOpOperands, outputs,
786                                  expandedOpIndexingMaps, iteratorTypes);
787   Region &fusedRegion = fusedOp->getRegion(0);
788   Region &originalRegion = genericOp->getRegion(0);
789   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
790 
791   // Update the index accesses after the expansion.
792   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
793 
794   // Reshape the result values to their original shape if this is a collapsing
795   // reshape folded into its consumer.
796   SmallVector<Value> resultVals;
797   for (OpResult opResult : genericOp->getOpResults()) {
798     int64_t resultNumber = opResult.getResultNumber();
799     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
800       SmallVector<ReassociationIndices> reassociation =
801           getReassociationForExpansion(
802               genericOp.getTiedIndexingMap(
803                   genericOp.getOutputOperand(resultNumber)),
804               expansionInfo);
805       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
806           genericOp.getLoc(), opResult.getType(),
807           fusedOp->getResult(resultNumber), reassociation));
808     } else {
809       resultVals.push_back(fusedOp->getResult(resultNumber));
810     }
811   }
812   // Assuming a single result.
813   return resultVals;
814 }
815 
816 namespace {
817 
818 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
819 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
820 /// in the consumer is expanded.
821 class FoldWithProducerReshapeOpByExpansion
822     : public OpRewritePattern<GenericOp> {
823 public:
824   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
825                                        ControlFusionFn foldReshapes,
826                                        PatternBenefit benefit = 1)
827       : OpRewritePattern<GenericOp>(context, benefit),
828         controlFoldingReshapes(std::move(foldReshapes)) {}
829 
830   LogicalResult matchAndRewrite(GenericOp genericOp,
831                                 PatternRewriter &rewriter) const override {
832     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
833       tensor::CollapseShapeOp reshapeOp =
834           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
835       if (!reshapeOp)
836         continue;
837       // Fold only if
838       // - The tensor reshape op is folding.
839       // - All constraints of fusing with reshape by expansion are met.
840       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
841           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
842         continue;
843 
844       Optional<SmallVector<Value>> replacementValues =
845           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
846       if (!replacementValues)
847         return failure();
848       rewriter.replaceOp(genericOp, replacementValues.getValue());
849       return success();
850     }
851     return failure();
852   }
853 
854 private:
855   ControlFusionFn controlFoldingReshapes;
856 };
857 
858 /// Pattern to fold a tensor_expand_shape op with its producer generic op
859 /// by expanding the dimensionality of the loop in the producer op.
860 struct FoldReshapeWithGenericOpByExpansion
861     : public OpRewritePattern<tensor::ExpandShapeOp> {
862 
863   FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
864                                       ControlFusionFn foldReshapes,
865                                       PatternBenefit benefit = 1)
866       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
867         controlFoldingReshapes(std::move(foldReshapes)) {}
868 
869   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
870                                 PatternRewriter &rewriter) const override {
871     // Fold only if all constraints of fusing with reshape by expansion are met.
872     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
873     if (!producer || producer.getNumOutputs() != 1 ||
874         !isFusableWithReshapeByDimExpansion(producer,
875                                             producer.getOutputOperand(0)) ||
876         !controlFoldingReshapes(producer->getResult(0),
877                                 reshapeOp->getOpOperand(0)))
878       return failure();
879     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
880         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
881     if (!replacementValues)
882       return failure();
883     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
884     return success();
885   }
886 
887 private:
888   ControlFusionFn controlFoldingReshapes;
889 };
890 } // namespace
891 
892 //===---------------------------------------------------------------------===//
893 // Methods and patterns to fuse reshape with linalg.generic operations by
894 // contraction of dimensions.
895 //===---------------------------------------------------------------------===//
896 
897 /// For a given list of indices in the range of the `indexingMap` that are
898 /// folded, return the indices of the corresponding domain. Return `llvm::None`
899 /// on failure. Ensures that all the elements of the returned reassociation are
900 /// distinct.
901 static ReassociationIndices
902 getDomainReassociation(AffineMap indexingMap,
903                        ReassociationIndicesRef rangeReassociation) {
904   assert(indexingMap.isProjectedPermutation() &&
905          "expected projected permutation");
906 
907   ReassociationIndices domainReassociation = llvm::to_vector<4>(
908       llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
909         return indexingMap.getResults()[pos]
910             .cast<AffineDimExpr>()
911             .getPosition();
912       }));
913   // The projected permutation semantics ensures that there is no repetition of
914   // the domain indices.
915   return domainReassociation;
916 }
917 
918 /// For a given `dimSequence`, check if the sequence is conserved in the
919 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
920 /// Non-existence of the sequence returns true as well.
921 static bool isDimSequencePreserved(AffineMap indexingMap,
922                                    ReassociationIndicesRef dimSequence) {
923   assert(!dimSequence.empty() &&
924          "expected non-empty list for dimension sequence");
925   assert(indexingMap.isProjectedPermutation() &&
926          "expected indexing map to be projected permutation");
927 
928   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
929   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
930 
931   unsigned dimSequenceStart = dimSequence[0];
932   for (const auto &expr : enumerate(indexingMap.getResults())) {
933     unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
934     // 1.  Check if this start of the sequence.
935     if (dimInMapStart == dimSequenceStart) {
936       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
937         return false;
938       // 1a. Check if sequence is preserved.
939       for (const auto &dimInSequence : enumerate(dimSequence)) {
940         unsigned dimInMap =
941             indexingMap.getResult(expr.index() + dimInSequence.index())
942                 .cast<AffineDimExpr>()
943                 .getPosition();
944         if (dimInMap != dimInSequence.value())
945           return false;
946       }
947       // Found the sequence. Projected permutation
948       // enforces that all AffineDimExprs in the result are unique, so no
949       // further checks are needed.
950       return true;
951     }
952     // 2. If position in the expr (which is of type AffineDimExpr) is part
953     // of sequence, return false here. This implies the entire sequence does not
954     // exist in the indexing map.
955     if (sequenceElements.count(dimInMapStart))
956       return false;
957   }
958   // 3. No element of sequence found. Return true.
959   return true;
960 }
961 
// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// `reassociation` is the reassociation of the reshape, expressed in terms of
// the dimensions of `fusableOperand`. Each returned entry is a list of
// iteration-domain dimensions of `genericOp` that can be folded together.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = linalg.init_tensor [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
    return {};

  // Every indexing map must be a projected permutation so that each result
  // maps back to a single iteration-domain dimension.
  if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<int64_t> reductionDims;
  for (const auto &iteratorType : llvm::enumerate(genericOp.iterator_types())) {
    if (isReductionIterator(iteratorType.value())) {
      reductionDims.push_back(iteratorType.index());
    }
  }

  // Iteration dims already claimed by an accepted folded group; a later group
  // touching any of them is skipped.
  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.iterator_types().getValue();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    // Translate the folded operand dims into iteration-domain dims.
    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that all folded iterator types are all parallel or all reductions.
    Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary for reductions that are associative and commutative, but
    // a more strict definition of reduction is used for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, the dims are trivially not contiguous.
        // This condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
          return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims);
        }))
      continue;

    // Accept this group and mark its dims as claimed.
    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}
1105 
1106 /// Helper class to carry state while collapsing the `linalg.generic` op.
1107 namespace {
1108 class CollapsingInfo {
1109 public:
1110   LogicalResult initialize(unsigned origNumLoops,
1111                            ArrayRef<ReassociationIndices> foldedIterationDims) {
1112     llvm::SmallDenseSet<int64_t, 4> processedDims;
1113     // Find all the dims that are folded.
1114     for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1115       if (foldedIterationDim.empty())
1116         continue;
1117       // If the folded dims contain dims already folded, that's illegal
1118       // specification. Repetition within a list is also illegal.
1119       for (auto dim : foldedIterationDim) {
1120         if (dim >= origNumLoops)
1121           return failure();
1122         if (processedDims.count(dim))
1123           return failure();
1124         processedDims.insert(dim);
1125       }
1126       collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1127                                                    foldedIterationDim.end());
1128     }
1129     if (processedDims.size() > origNumLoops)
1130       return failure();
1131 
1132     // Add all the preserved dims of the original op as single
1133     // elements to `collapsedOpToOrigOpIterationDim`.
1134     for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1135       if (processedDims.count(dim))
1136         continue;
1137       collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1138     }
1139 
1140     llvm::sort(collapsedOpToOrigOpIterationDim,
1141                [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1142                  return lhs[0] < rhs[0];
1143                });
1144     origOpToCollapsedOpIterationDim.resize(origNumLoops);
1145     for (const auto &foldedDims :
1146          llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1147       for (const auto &dim : enumerate(foldedDims.value()))
1148         origOpToCollapsedOpIterationDim[dim.value()] =
1149             std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1150     }
1151     return success();
1152   }
1153 
1154   /// Return mapping from collapsed loop domain to original loop domain.
1155   ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1156     return collapsedOpToOrigOpIterationDim;
1157   }
1158 
1159   /// Return mapping from original loop domain to collapsed loop domain. The
1160   /// mapping is a pair. First value is the dimension in the collapsed loop that
1161   /// the original loop is mapped to. Second is the relative position in folded
1162   /// list of this domain. For example if the original loop domain is 3D, and
1163   /// the collapsed loop domain is folding all of it, i.e.
1164   ///
1165   /// ```
1166   /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
1167   /// ```
1168   ///
1169   /// then
1170   ///
1171   /// ```
1172   ///  origOpToCollapsedOpMapping[0] = {0, 0};
1173   ///  origOpToCollapsedOpMapping[1] = {0, 1};
1174   ///  origOpToCollapsedOpMapping[2] = {0, 2};
1175   ///  origOpToCollapsedOpMapping[3] = {1, 0};
1176   ///  origOpToCollapsedOpMapping[4] = {1, 1};
1177   /// ```
1178   ///
1179   ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1180     return origOpToCollapsedOpIterationDim;
1181   }
1182 
1183   /// Return the collapsed op iteration domain rank.
1184   unsigned getCollapsedOpIterationRank() const {
1185     return collapsedOpToOrigOpIterationDim.size();
1186   }
1187 
1188 private:
1189   /// Map from the iteration domain index in collapsed op to the iteration
1190   /// domain indices in the original op.
1191   SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1192 
1193   /// Map from iteration domain index in the original op to the iteration domain
1194   /// index in the collapsed op.
1195   SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1196 };
1197 } // namespace
1198 
1199 /// Get the iterator types for the collapsed operation given the original
1200 /// iterator types and collapsed dimensions.
1201 static SmallVector<StringRef>
1202 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
1203                             const CollapsingInfo &collapsingInfo) {
1204   SmallVector<StringRef> collapsedIteratorTypes;
1205   for (ReassociationIndicesRef foldedIterDims :
1206        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1207     assert(!foldedIterDims.empty() &&
1208            "reassociation indices expected to have non-empty sets");
1209     // Just pick the iterator type of the first folded dim. Pre-condition checks
1210     // expected to have checked that iterator types of all folded dimensions are
1211     // the same.
1212     collapsedIteratorTypes.push_back(
1213         iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1214   }
1215   return collapsedIteratorTypes;
1216 }
1217 
1218 /// Compute the indexing map in the collapsed op that corresponds to the given
1219 /// `indexingMap` of the original operation.
1220 static AffineMap
1221 getCollapsedOpIndexingMap(AffineMap indexingMap,
1222                           const CollapsingInfo &collapsingInfo) {
1223   MLIRContext *context = indexingMap.getContext();
1224   assert(indexingMap.isProjectedPermutation() &&
1225          "expected indexing map to be projected permutation");
1226   SmallVector<AffineExpr> resultExprs;
1227   auto origOpToCollapsedOpMapping =
1228       collapsingInfo.getOrigOpToCollapsedOpMapping();
1229   for (auto expr : indexingMap.getResults()) {
1230     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1231     // If the dim is not the first of the collapsed dim, do nothing.
1232     if (origOpToCollapsedOpMapping[dim].second != 0)
1233       continue;
1234     // The next n-dims are guaranteed to be collapsed. So just use the
1235     // iteration dimension of the collapsed op.
1236     resultExprs.push_back(
1237         getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1238   }
1239   return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1240                         resultExprs, context);
1241 }
1242 
1243 /// Return the `reassociation` indices to use to collapse the operand when the
1244 /// iteration space of a generic op is collapsed.
1245 static SmallVector<ReassociationIndices>
1246 getOperandReassociation(AffineMap indexingMap,
1247                         const CollapsingInfo &collapsingInfo) {
1248   unsigned counter = 0;
1249   SmallVector<ReassociationIndices> operandReassociation;
1250   auto origOpToCollapsedOpMapping =
1251       collapsingInfo.getOrigOpToCollapsedOpMapping();
1252   auto collapsedOpToOrigOpMapping =
1253       collapsingInfo.getCollapsedOpToOrigOpMapping();
1254   while (counter < indexingMap.getNumResults()) {
1255     unsigned dim =
1256         indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
1257     if (origOpToCollapsedOpMapping[dim].second == 0) {
1258       // This is the start of a collapsed dimensions of the iteration that
1259       // is gauranteed to be preserved in the indexing map. The number of folded
1260       // dims is obtained from the collapsed op to original op mapping.
1261       unsigned numFoldedDims =
1262           collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1263               .size();
1264       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1265       operandReassociation.emplace_back(range.begin(), range.end());
1266       counter += numFoldedDims;
1267     }
1268   }
1269   return operandReassociation;
1270 }
1271 
1272 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1273 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1274                                    OpOperand *opOperand,
1275                                    const CollapsingInfo &collapsingInfo,
1276                                    OpBuilder &builder) {
1277   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1278   SmallVector<ReassociationIndices> operandReassociation =
1279       getOperandReassociation(indexingMap, collapsingInfo);
1280 
1281   // If the number of entries in the reassocation for the operand is same as the
1282   // number of results of the indexing map, then nothing to do for this operand.
1283   Value operand = opOperand->get();
1284   if (operandReassociation.size() == indexingMap.getNumResults())
1285     return operand;
1286 
1287   // Insert a reshape to collapse the dimensions.
1288   auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1289       loc, operand, operandReassociation);
1290   return reshapeOp.getResult();
1291 }
1292 
1293 /// Modify the `linalg.index` operations in the original generic op, to its
1294 /// value in the collapsed operation.
1295 void generateCollapsedIndexingRegion(Location loc, Block *block,
1296                                      const CollapsingInfo &collapsingInfo,
1297                                      ValueRange loopRange,
1298                                      PatternRewriter &rewriter) {
1299   OpBuilder::InsertionGuard g(rewriter);
1300   rewriter.setInsertionPointToStart(block);
1301 
1302   // Collect all the original index ops.
1303   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1304 
1305   // For each folded dimension list resolve the original induction variable
1306   // values in terms of the folded dimension induction variable.
1307   //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1308   // can be inverted to
1309   //   i2 = i_{folded} % d2
1310   //   i1 = (i_{folded} / d2) % d1
1311   //   i0 = i_{folded} / (d1 * d2)
1312   llvm::DenseMap<unsigned, Value> indexReplacementVals;
1313   for (auto &foldedDims :
1314        enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1315     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1316     Value newIndexVal =
1317         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1318     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1319       indexReplacementVals[dim] =
1320           rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1321       newIndexVal =
1322           rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1323     }
1324     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1325   }
1326 
1327   for (auto indexOp : indexOps) {
1328     auto dim = indexOp.dim();
1329     rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1330   }
1331 }
1332 
1333 /// Implementation of fusion with reshape operation by collapsing dimensions.
/// Implementation of fusion with reshape operation by collapsing dimensions.
///
/// Builds a new `linalg.generic` whose iteration space has each group of loops
/// in `foldedIterationDims` collapsed into a single loop:
///   - iterator types and indexing maps are rewritten for the collapsed space,
///   - input and output operands are collapsed (via `getCollapsedOpOperand`)
///     and the output types become the result types of the new op,
///   - the payload block of `genericOp` is merged into the new op with
///     `mergeBlocks` (it is moved, not cloned), so `genericOp` must only be
///     replaced by the caller afterwards,
///   - if the payload uses `linalg.index`, the original indices are recomputed
///     from the collapsed ones,
///   - results whose rank changed are expanded back to the original result
///     types with `tensor.expand_shape`.
/// Returns the values that should replace the results of `genericOp`, or
/// failure before any IR is created when there is nothing to collapse or
/// `CollapsingInfo::initialize` rejects the requested grouping.
/// NOTE(review): `fusableOpOperand` is not used by this function.
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
  // Bail on trivial no-op cases.
  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                       foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        genericOp, "illegal to collapse specified dimensions");
  }

  // Get the iterator types of the collapsed op.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc = genericOp->getLoc();

  // Get the input operands.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  // The (possibly collapsed) output types are used as the result types of the
  // new op.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op.
  // The body builder is empty: the payload of the original op is merged in
  // right below, reusing the new block's arguments.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    // The original loop extents are needed to delinearize the collapsed
    // indices; they are materialized right before the new op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshape for the result to get back the original result
  // type.
  // Results whose rank did not change are forwarded as-is.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}
1434 
1435 namespace {
1436 
1437 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1438 /// contracting dimensions of the loop.
1439 class FoldWithProducerReshapeOpByCollapsing
1440     : public OpRewritePattern<GenericOp> {
1441 public:
1442   FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1443                                         ControlFusionFn foldReshapes,
1444                                         PatternBenefit benefit = 1)
1445       : OpRewritePattern<GenericOp>(context, benefit),
1446         controlFoldingReshapes(std::move(foldReshapes)) {}
1447 
1448   LogicalResult matchAndRewrite(GenericOp genericOp,
1449                                 PatternRewriter &rewriter) const override {
1450     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1451       tensor::ExpandShapeOp reshapeOp =
1452           opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
1453       if (!reshapeOp)
1454         continue;
1455 
1456       SmallVector<ReassociationIndices> collapsableIterationDims =
1457           getCollapsableIterationSpaceDims(genericOp, opOperand,
1458                                            reshapeOp.getReassociationIndices());
1459       if (collapsableIterationDims.empty() ||
1460           !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
1461         continue;
1462       }
1463 
1464       Optional<SmallVector<Value>> replacements =
1465           collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1466                                          opOperand, rewriter);
1467       if (!replacements) {
1468         return rewriter.notifyMatchFailure(
1469             genericOp, "failed to do the fusion by collapsing transformation");
1470       }
1471 
1472       rewriter.replaceOp(genericOp, replacements.getValue());
1473       return success();
1474     }
1475     return failure();
1476   }
1477 
1478 private:
1479   ControlFusionFn controlFoldingReshapes;
1480 };
1481 } // namespace
1482 
1483 //===---------------------------------------------------------------------===//
1484 // Methods and patterns that fuse constants with linalg.generic operations.
1485 //===---------------------------------------------------------------------===//
1486 
1487 namespace {
1488 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1489 /// handle cases where the constant is not single-valued.
1490 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1491 public:
1492   FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1493       : OpRewritePattern<GenericOp>(context, benefit) {}
1494 
1495   LogicalResult matchAndRewrite(GenericOp genericOp,
1496                                 PatternRewriter &rewriter) const override {
1497     if (!genericOp.hasTensorSemantics())
1498       return failure();
1499     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1500       Operation *def = opOperand->get().getDefiningOp();
1501       Attribute constantAttr;
1502       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1503         {
1504           DenseElementsAttr splatAttr;
1505           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1506               splatAttr.isSplat() &&
1507               splatAttr.getType().getElementType().isIntOrFloat()) {
1508             constantAttr = splatAttr.getSplatValue<Attribute>();
1509             return true;
1510           }
1511         }
1512         {
1513           IntegerAttr intAttr;
1514           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1515             constantAttr = intAttr;
1516             return true;
1517           }
1518         }
1519         {
1520           FloatAttr floatAttr;
1521           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1522             constantAttr = floatAttr;
1523             return true;
1524           }
1525         }
1526         return false;
1527       };
1528 
1529       auto resultValue = opOperand->get().dyn_cast<OpResult>();
1530       if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1531         continue;
1532 
1533       // The operands and the indexing_maps of the fused operation the same as
1534       // the operands and indexing_maps of the generic operations with the
1535       // values at the constant index dropped.
1536       SmallVector<AffineMap> fusedIndexMaps;
1537       SmallVector<Value> fusedOperands;
1538       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1539       fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1540       fusedOperands.reserve(genericOp.getNumInputs());
1541       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1542       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1543         if (inputOperand == opOperand)
1544           continue;
1545         Value inputValue = inputOperand->get();
1546         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1547         fusedOperands.push_back(inputValue);
1548         fusedLocs.push_back(inputValue.getLoc());
1549       }
1550       for (OpOperand *outputOperand : genericOp.getOutputOperands())
1551         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1552 
1553       // Check if the operation shapes to loops map is computable.
1554       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1555         return rewriter.notifyMatchFailure(
1556             genericOp, "fused op loop bound computation failed");
1557       }
1558 
1559       // Create a constant scalar value from the splat constant.
1560       Value scalarConstant = rewriter.create<arith::ConstantOp>(
1561           def->getLoc(), constantAttr, constantAttr.getType());
1562 
1563       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1564       auto fusedOp = rewriter.create<GenericOp>(
1565           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1566           /*inputs=*/fusedOperands,
1567           /*outputs=*/outputOperands,
1568           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1569           genericOp.iterator_types(),
1570           /*doc=*/nullptr,
1571           /*library_call=*/nullptr);
1572 
1573       // Map the block argument corresponding to the replaced argument with the
1574       // scalar constant.
1575       Region &region = genericOp->getRegion(0);
1576       Block &entryBlock = *region.begin();
1577       BlockAndValueMapping mapping;
1578       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1579                   scalarConstant);
1580       Region &fusedRegion = fusedOp->getRegion(0);
1581       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1582                                  mapping);
1583       rewriter.replaceOp(genericOp, fusedOp->getResults());
1584       return success();
1585     }
1586     return failure();
1587   }
1588 };
1589 
1590 } // namespace
1591 
1592 //===---------------------------------------------------------------------===//
1593 // Miscellaneous patterns that help fusion.
1594 //===---------------------------------------------------------------------===//
1595 
1596 namespace {
1597 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
1598 /// the value of the `outs` operand is not used within the op.  This is only
1599 /// implemented for `linalg.generic` operations for now, but should hold for all
1600 /// linalg structured ops.
1601 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1602   using OpRewritePattern<GenericOp>::OpRewritePattern;
1603 
1604   LogicalResult matchAndRewrite(GenericOp op,
1605                                 PatternRewriter &rewriter) const override {
1606     rewriter.startRootUpdate(op);
1607     bool modifiedOutput = false;
1608     Location loc = op.getLoc();
1609     for (OpOperand *opOperand : op.getOutputOperands()) {
1610       if (!op.payloadUsesValueFromOperand(opOperand)) {
1611         Value operandVal = opOperand->get();
1612         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1613         if (!operandType)
1614           continue;
1615 
1616         // If outs is sparse, leave it to the sparse compiler.
1617         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
1618           continue;
1619 
1620         // If outs is already an `init_tensor` operation, nothing to do.
1621         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1622         if (definingOp)
1623           continue;
1624         modifiedOutput = true;
1625         SmallVector<Value> dynamicDims;
1626         for (const auto &dim : llvm::enumerate(operandType.getShape())) {
1627           if (dim.value() != ShapedType::kDynamicSize)
1628             continue;
1629           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1630               loc, operandVal, dim.index()));
1631         }
1632         Value initTensor = rewriter.create<InitTensorOp>(
1633             loc, dynamicDims, operandType.getShape(),
1634             operandType.getElementType());
1635         op->setOperand(opOperand->getOperandNumber(), initTensor);
1636       }
1637     }
1638     if (!modifiedOutput) {
1639       rewriter.cancelRootUpdate(op);
1640       return failure();
1641     }
1642     rewriter.finalizeRootUpdate(op);
1643     return success();
1644   }
1645 };
1646 
1647 /// Fold linalg.fill into linalg.generic
1648 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
1649   using OpRewritePattern<GenericOp>::OpRewritePattern;
1650 
1651   LogicalResult matchAndRewrite(GenericOp genericOp,
1652                                 PatternRewriter &rewriter) const override {
1653     if (!genericOp.hasTensorSemantics())
1654       return failure();
1655     bool fillFound = false;
1656     Block &payload = genericOp.region().front();
1657     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1658       if (!genericOp.payloadUsesValueFromOperand(opOperand))
1659         continue;
1660       FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1661       if (!fillOp)
1662         continue;
1663       fillFound = true;
1664       payload.getArgument(opOperand->getOperandNumber())
1665           .replaceAllUsesWith(fillOp.value());
1666     }
1667     return success(fillFound);
1668   }
1669 };
1670 } // namespace
1671 
1672 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1673     RewritePatternSet &patterns,
1674     const ControlFusionFn &controlFoldingReshapes) {
1675   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1676                                                     controlFoldingReshapes);
1677   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1678                                                      controlFoldingReshapes);
1679 }
1680 
1681 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1682     RewritePatternSet &patterns,
1683     const ControlFusionFn &controlFoldingReshapes) {
1684   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
1685                                                       controlFoldingReshapes);
1686 }
1687 
1688 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1689     RewritePatternSet &patterns,
1690     const ControlFusionFn &controlElementwiseOpsFusion) {
1691   auto *context = patterns.getContext();
1692   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1693   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1694                RemoveOutsDependency>(context);
1695 }
1696 
1697 //===---------------------------------------------------------------------===//
1698 // Passes
1699 //===---------------------------------------------------------------------===//
1700 
1701 namespace {
1702 
1703 /// Pass that fuses generic ops on tensors. Used only for testing.
1704 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
1705 // patterns added here heavily depends on the cost function used. Having an
1706 // opinionated pass of this form is not recommended. Deprecate this pass in
1707 // favor of test passes that check the functionality of each of the patterns
1708 // added here individually.
1709 struct LinalgElementwiseOpFusionPass
1710     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1711   void runOnOperation() override {
1712     Operation *op = getOperation();
1713     MLIRContext *context = op->getContext();
1714     RewritePatternSet patterns(context);
1715 
1716     // Add folding with reshape by expansion patterns.
1717     ControlFusionFn defaultControlFn = [](const OpResult &producer,
1718                                           const OpOperand &consumer) {
1719       return producer.hasOneUse();
1720     };
1721 
1722     // Add elementwise op fusion patterns.
1723     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
1724 
1725     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
1726 
1727     // Add the sparse tensor rewriting patterns.
1728     populateSparseTensorRewriting(patterns);
1729 
1730     // General canonicalization patterns.
1731     AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1732     GenericOp::getCanonicalizationPatterns(patterns, context);
1733     tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1734     tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1735     context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1736         patterns);
1737 
1738     // Add constant folding patterns.
1739     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
1740 
1741     // Use TopDownTraversal for compile time reasons
1742     GreedyRewriteConfig grc;
1743     grc.useTopDownTraversal = true;
1744     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1745                                        grc);
1746   }
1747 };
1748 
1749 } // namespace
1750 
1751 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1752   return std::make_unique<LinalgElementwiseOpFusionPass>();
1753 }
1754