1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect fusion-on-tensors pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 //===---------------------------------------------------------------------===//
32 // Methods and patterns that fuse elementwise `linalg.generic` operations.
33 //===---------------------------------------------------------------------===//
34 
35 /// Return the indexing map to use for an operand of the `producer` in the
36 /// fused operation, given the indexing map of the producer's result in the
37 /// consumer (`fusedConsumerArgIndexMap`).
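///
/// As an illustrative sketch (the maps here are hypothetical): if
/// `producerResultIndexMap` is (d0, d1) -> (d1, d0), the producer operand's
/// map is (d0, d1) -> (d0), and `fusedConsumerArgIndexMap` is
/// (e0, e1) -> (e1, e0), then the inverse of the result map is
/// (t0, t1) -> (t1, t0), composing it with the operand map gives
/// (t0, t1) -> (t1), and composing that with the consumer map yields
/// (e0, e1) -> (e0) as the operand's indexing map in the fused op.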
38 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
39     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
40     AffineMap fusedConsumerArgIndexMap) {
41   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
42   // from consumer loop -> consumer arg tensor index/producer result tensor
43   // index. The fused loop is the same as the consumer loop. For each producer
44   // arg, the indexing map to be computed is a map from consumer loop ->
45   // producer arg tensor index.
46   // producerResultIndexMap is a map from producer loop -> tensor index.
47   // Compute the inverse to get map from tensor index -> producer loop.
48   // The inverse is a map from producer result tensor index -> producer loop.
49   AffineMap invProducerResultIndexMap =
50       inversePermutation(producerResultIndexMap);
51   assert(invProducerResultIndexMap &&
52          "expected producer result indexing map to be invertible");
53 
54   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
55   // argMap is a map from producer loop -> producer arg tensor index.
56   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
57 
58   // Compose argMap with invProducerResultIndexMap to get a map from
59   // producer result tensor index -> producer arg tensor index.
60   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
61 
62   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
63   // consumer loop/ fused loop -> producer arg tensor index.
64   return t1.compose(fusedConsumerArgIndexMap);
65 }
66 
67 /// Conditions for elementwise fusion of generic operations.
68 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
69                                      OpOperand *consumerOpOperand) {
70   // Producer and consumer must have tensor semantics.
71   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
72     return false;
73 
74   // Verify that
75   // - the producer has only "parallel" iterator types.
76   if (producer.getNumParallelLoops() != producer.getNumLoops())
77     return false;
78 
79   // Only allow fusing the producer of an input operand for now.
80   // TODO: allow fusing the producer of an output operand.
81   if (!consumer.isInputTensor(consumerOpOperand))
82     return false;
83 
84   // Get the consumer index map. The number of results of the consumer index
85   // map must match the number of loops of the producer.
86   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
88     return false;
89 
90   // Currently only operations with a single result are supported.
91   if (producer.getNumOutputs() != 1)
92     return false;
93 
94   // Finally the index_map for the result must be invertible. For now just
95   // verify it is a permutation.
96   AffineMap producerResultIndexMap =
97       producer.getTiedIndexingMap(producer.getOutputOperand(0));
98   if (!producerResultIndexMap.isPermutation())
99     return false;
100 
101   // Ensure that the fusion does not remove size information required to
102   // get the loop bounds. For non-reduction generics, this is trivially the
103   // case due to the output operand. For reductions, we need to check that after
104   // the fusion, each loop dimension has at least one input that defines it.
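  // As a hypothetical example: if the consumer reduces along d1, the fused
  // operand is the only consumer operand whose indexing map references d1, and
  // none of the producer's inputs cover d1 after being remapped into the fused
  // loops, then the loop bound of d1 would no longer be derivable, so the
  // fusion must be rejected.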
105   if (consumer.getNumReductionLoops()) {
106     BitVector coveredDims(consumer.getNumLoops(), false);
107 
108     auto addToCoveredDims = [&](AffineMap map) {
109       for (auto result : map.getResults())
110         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
111           coveredDims[dimExpr.getPosition()] = true;
112     };
113 
114     for (auto pair :
115          llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
116       Value operand = std::get<0>(pair);
117       if (operand == consumerOpOperand->get())
118         continue;
119       AffineMap operandMap = std::get<1>(pair);
120       addToCoveredDims(operandMap);
121     }
122 
123     for (OpOperand *operand : producer.getInputOperands()) {
124       AffineMap newIndexingMap =
125           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
126               operand, producerResultIndexMap, consumerIndexMap);
127       addToCoveredDims(newIndexingMap);
128     }
129     if (!coveredDims.all())
130       return false;
131   }
132 
133   return true;
134 }
135 
136 /// Generate the region of the fused tensor operation. The region of the fused
137 /// op must be empty.
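/// Note (summary of the numbered steps below): the fused block arguments are,
/// in order, the consumer inputs before the fused operand, the producer
/// inputs, the producer's "init" output when its value is actually read, the
/// remaining consumer inputs, and finally the consumer outputs.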
138 static void
139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
140                                  AffineMap consumerToProducerLoopsMap,
141                                  OpOperand *consumerOpOperand,
142                                  unsigned nloops) {
143   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
144   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
145   // Build the region of the fused op.
146   Block &producerBlock = producer->getRegion(0).front();
147   Block &consumerBlock = consumer->getRegion(0).front();
148   Block *fusedBlock = new Block();
149   fusedOp.region().push_back(fusedBlock);
150   BlockAndValueMapping mapper;
151   OpBuilder::InsertionGuard guard(rewriter);
152   rewriter.setInsertionPointToStart(fusedBlock);
153 
154   // 2. Add an index operation for every fused loop dimension and use the
155   // `consumerToProducerLoopsMap` to map the producer indices.
156   if (producer.hasIndexSemantics()) {
157     // Add an index operation for every fused loop dimension.
158     unsigned numFusedOpLoops =
159         std::max(producer.getNumLoops(), consumer.getNumLoops());
160     SmallVector<Value> fusedIndices;
161     fusedIndices.reserve(numFusedOpLoops);
162     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
164                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
165                     });
166     for (IndexOp indexOp :
167          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
169           producer.getLoc(),
170           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
171       mapper.map(indexOp.getResult(), newIndex);
172     }
173   }
174   // TODO: allow fusing the producer of an output operand.
175   assert(consumer.isInputTensor(consumerOpOperand) &&
176          "expected producer of input operand");
177   // 3. Consumer input operands up to consumerIdx (exclusive).
178   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
179            consumerOpOperand->getOperandNumber())) // input assumption.
180     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
181 
182   // Replacing consumerIdx requires getting the cloned, yielded, value from
183   // the (cloned) producer block. This happens in step 9.
184 
185   // 4. Splice in producer's input operands.
186   for (BlockArgument bbArg :
187        producerBlock.getArguments().take_front(producer.getNumInputs()))
188     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
189 
190   // 4.b. Producer output operand/map that is fused needs to be mapped to the
191   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
192   assert(producer->getNumResults() == 1 && "expected single result producer");
193   if (producer.isInitTensor(producer.getOutputOperand(0))) {
194     BlockArgument bbArg = producerBlock.getArguments()
195                               .drop_front(producer.getNumInputs())
196                               // TODO: bbArg index of
197                               .front();
198     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
199   }
200   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
201   for (BlockArgument bbArg :
202        consumerBlock.getArguments()
203            .take_front(consumer.getNumInputs())
204            .drop_front(consumerOpOperand->getOperandNumber() + 1))
205     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
206   // 6. All of consumer's output operands.
207   for (BlockArgument bbArg :
208        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
209     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
210   // 7. All of producer's output operands except the one fused.
211   // TODO: allow fusion of multi-result producers.
212   assert(producer->getNumResults() == 1 && "expected single result producer");
213 
214   // 8. Clone all producer operations except for the yield and index operations
215   // to the fused operation.
216   for (auto &op : producerBlock.without_terminator()) {
217     if (!isa<IndexOp>(op))
218       rewriter.clone(op, mapper);
219   }
220   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
221   // forward the yield operand.
222   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
223   // TODO: allow fusion of multi-result producers.
224   assert(producer->getNumResults() == 1 && "expected single result producer");
225   unsigned producerResultNumber = 0;
226   Value replacement =
227       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
228   // Sanity checks: if the replacement is not already in the mapper then it
229   // must be defined outside the producer block.
230   if (replacement == yieldOp.getOperand(producerResultNumber)) {
231     if (auto bb = replacement.dyn_cast<BlockArgument>())
232       assert(bb.getOwner() != &producerBlock &&
233              "yielded block argument must have been mapped");
234     else
235       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
236              "yielded value must have been mapped");
237   }
238   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
239              replacement);
240   // 10. Clone operations from the consumer to the fused op.
241   for (auto &op : consumerBlock.getOperations())
242     rewriter.clone(op, mapper);
243 
244   // Sanity checks.
245   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
246          "Ill-formed GenericOp region");
247 }
248 
249 static Optional<SmallVector<Value>>
250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
251                        const ControlFusionFn &controlFn,
252                        PatternRewriter &rewriter) {
253   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
254   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
255       !controlFn(producer->getResult(0), *consumerOpOperand))
256     return llvm::None;
257 
258   // TODO: allow fusing the producer of an output operand.
259   assert(consumer.isInputTensor(consumerOpOperand) &&
260          "expected producer of input operand");
261 
262   // Compute the fused operands list and indexing maps.
263   SmallVector<Value> fusedOperands;
264   SmallVector<AffineMap> fusedIndexMaps;
265   fusedOperands.reserve(producer->getNumOperands() +
266                         consumer->getNumOperands());
267   fusedIndexMaps.reserve(producer->getNumOperands() +
268                          consumer->getNumOperands());
269   // In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
270   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
271   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
272   SmallVector<OpOperand *>::iterator it =
273       llvm::find(consumerInputs, consumerOpOperand);
274   assert(it != consumerInputs.end() && "expected to find the consumer operand");
275   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276     fusedOperands.push_back(opOperand->get());
277     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
278   }
279   // 4. Splice in producer's input operands/maps.
280   assert(producer->getNumResults() == 1 && "expected single result producer");
281   AffineMap producerResultIndexMap =
282       producer.getTiedIndexingMap(producer.getOutputOperand(0));
283   for (OpOperand *opOperand : producer.getInputOperands()) {
284     fusedOperands.push_back(opOperand->get());
285     // Compute indexing maps for the producer args in the fused operation.
286     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
287         opOperand, producerResultIndexMap,
288         consumer.getTiedIndexingMap(consumerOpOperand));
289     fusedIndexMaps.push_back(map);
290   }
291   // 4.b. Producer output operand/map that is fused needs to be passed if it is
292   // an "initTensor" (i.e. its value is actually read).
293   assert(producer->getNumResults() == 1 && "expected single result producer");
294   if (producer.isInitTensor(producer.getOutputOperand(0))) {
295     fusedOperands.push_back(producer.getOutputOperand(0)->get());
296     // Compute indexing maps for the producer args in the fused operation.
297     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
298         producer.getOutputOperand(0), producerResultIndexMap,
299         consumer.getTiedIndexingMap(consumerOpOperand));
300     fusedIndexMaps.push_back(map);
301   }
302   // 5. Remaining consumer's input operands/maps (drop past index
303   // `consumerIdx`).
304   for (OpOperand *opOperand :
305        llvm::make_range(std::next(it), consumerInputs.end())) {
306     fusedOperands.push_back(opOperand->get());
307     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
308   }
309   // 6. All of consumer's output operands (skip operands: added by the builder).
310   for (OpOperand *opOperand : consumer.getOutputOperands())
311     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
312   // 7. All of producer's output operands/maps except the one fused.
313   // TODO: allow fusion of multi-result producers.
314   assert(producer->getNumResults() == 1 && "expected single result producer");
315 
316   // Generate the fused op.
317   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
318   auto fusedOp = rewriter.create<GenericOp>(
319       consumer.getLoc(), consumer->getResultTypes(),
320       /*inputs=*/fusedOperands,
321       // TODO: handle outputs.
322       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
323       consumer.iterator_types(),
324       /*doc=*/nullptr,
325       /*library_call=*/nullptr);
326   if (!fusedOp.getShapesToLoopsMap()) {
327     // Fused op has invalid indexing maps. Typically this means something is off
328     // in the input, but going ahead here would result in verification errors.
329     // So clean up and abort.
330     rewriter.eraseOp(fusedOp);
331     return llvm::None;
332   }
333 
334   // Construct an AffineMap from consumer loops to producer loops.
335   // consumer loop -> tensor index
336   AffineMap consumerResultIndexMap =
337       consumer.getTiedIndexingMap(consumerOpOperand);
338   // tensor index -> producer loop
339   AffineMap invProducerResultIndexMap =
340       inversePermutation(producerResultIndexMap);
341   assert(invProducerResultIndexMap &&
342          "expected producer result indexing map to be invertible");
343   // consumer loop -> producer loop
344   AffineMap consumerToProducerLoopsMap =
345       invProducerResultIndexMap.compose(consumerResultIndexMap);
346 
347   generateFusedElementwiseOpRegion(rewriter, fusedOp,
348                                    consumerToProducerLoopsMap,
349                                    consumerOpOperand, consumer.getNumLoops());
350   return SmallVector<Value>(fusedOp->getResults());
351 }
352 
353 static Optional<SmallVector<Value>>
354 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
355                    GenericOp producer, const ControlFusionFn &controlFn) {
356   if (producer->getNumResults() != 1)
357     return llvm::None;
358 
359   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360                                 rewriter);
361 }
362 
363 namespace {
364 /// Pattern to fuse a generic op with the producers of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
367   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
368                      PatternBenefit benefit = 1)
369       : OpRewritePattern<GenericOp>(context, benefit),
370         controlFn(std::move(fun)) {}
371 
372   LogicalResult matchAndRewrite(GenericOp genericOp,
373                                 PatternRewriter &rewriter) const override {
374     // Find the first operand that is defined by another generic op on tensors.
375     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
376       auto producer =
377           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378       if (!producer || !producer.hasTensorSemantics())
379         continue;
380       Optional<SmallVector<Value>> fusedOpResults =
381           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
382       if (fusedOpResults) {
383         rewriter.replaceOp(genericOp, *fusedOpResults);
384         return success();
385       }
386     }
387     return failure();
388   }
389 
390 private:
391   ControlFusionFn controlFn;
392 };
393 } // namespace
394 
395 //===---------------------------------------------------------------------===//
396 // Methods and patterns that fuse reshape ops with elementwise operations by
397 // expanding the dimensionality of the elementwise operations.
398 //===---------------------------------------------------------------------===//
399 
400 /// Conditions for folding a generic operation with a reshape op by expanding
401 /// the iteration space dimensionality for tensor operations. These are
402 /// preconditions assumed by `fuseWithReshapeByExpansion` which implements the
403 /// following fusion pattern.
404 ///
405 ///  Consider
406 ///
407 ///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
408 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
409 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
410 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
411 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
412 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
413 ///
414 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
415 ///  is increased to match the result (operand) of the tensor.expand_shape.
416 ///  The indexing_map of the fused tensor in the `genericOp` and the
417 ///  reassociation map helps compute the indexing maps of the modified op.
418 ///  For the above example, based on the reassociation map it
419 ///  can be concluded that
420 ///
421 ///  - The loop used to access the first dimension of the fused tensor is split
422 ///    into two.
423 ///  - The loop used to access the second dimension of the fused tensor is kept
424 ///    as is.
425 ///  - The loop used to access the third dimension of the fused tensor is split
426 ///    into three.
427 ///
428 ///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
429 ///  modified op, then
430 ///
431 ///   d0 -> e0, e1
432 ///   d1 -> e2, e3, e4
433 ///   d2 -> e5
434 ///
435 ///  substituting this, the generic op can be rewritten as
436 ///
437 ///  %d = linalg.generic ins(%0, %1 : )
438 ///        indexing_maps =
439 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
440 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
441 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
442 ///
443 ///  Since the iteration space of the modified op is now 6D, reshapes are
444 ///  introduced on the operands to make them consistent with the new maps:
445 ///
446 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
447 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
448 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
449 ///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
450 ///
451 ///  The added reshapes are again expanding patterns, so they will get fused
452 ///  with their producers if possible.
453 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
454                                                OpOperand *fusableOpOperand) {
455   // Is fusable only if:
456   // - All the indexing maps for operands and results are projected
457   //   permutations.
458   // - The fused tensor is not a scalar.
459   // - All the loops are parallel loops.
460   return genericOp.hasTensorSemantics() &&
461          llvm::all_of(genericOp.indexing_maps().getValue(),
462                       [](Attribute attr) {
463                         return attr.cast<AffineMapAttr>()
464                             .getValue()
465                             .isProjectedPermutation();
466                       }) &&
467          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
468          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
469            return attr.cast<StringAttr>().getValue() ==
470                   getParallelIteratorTypeName();
471          });
472 }
473 
474 namespace {
475 /// Information needed to expand a generic operation to fold the reshape with
476 /// it.
477 class ExpansionInfo {
478 public:
479   // Computes the mapping from original dimensions of the op to the dimensions
480   // of the expanded op given the `indexingMap` of the fused operand/result of
481   // the generic op, the `reassociationMaps` of the reshape op and the shape of
482   // the expanded op.
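  //
  // As an illustrative sketch (shapes are hypothetical): for a fused operand
  // with indexing map (d0, d1) -> (d1, d0), reassociation maps [[0, 1], [2]]
  // and expanded shape [2, 3, 4], loop d1 expands to two dimensions of extents
  // {2, 3} while loop d0 keeps a single dimension of extent {4}; the expanded
  // op then has 3 loops, with reassociation d0 -> [0] and d1 -> [1, 2].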
483   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
484                         ArrayRef<AffineMap> reassociationMaps,
485                         ArrayRef<int64_t> expandedShape,
486                         ArrayRef<int64_t> collapsedShape,
487                         PatternRewriter &rewriter);
488   unsigned getOrigOpNumDims() const { return reassociation.size(); }
489   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
490   ReassociationIndicesRef getExpandedDims(unsigned i) const {
491     return reassociation[i];
492   }
493   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
494     return expandedShapeMap[i];
495   }
496   ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
497 
498 private:
499   /// Reassociation from the dimensions in the original operation to the
500   /// dimension of the expanded operation.
501   SmallVector<ReassociationIndices> reassociation;
502   /// Mapping from extent of loops in the original operation, to the extent of
503   /// loops in the expanded operation.
504   SmallVector<SmallVector<int64_t>> expandedShapeMap;
505   /// Extent of the loop in the original operation.
506   SmallVector<int64_t> originalLoopExtent;
507   unsigned expandedOpNumDims;
508 };
509 } // namespace
510 
511 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
512                                      OpOperand *fusableOpOperand,
513                                      ArrayRef<AffineMap> reassociationMaps,
514                                      ArrayRef<int64_t> expandedShape,
515                                      ArrayRef<int64_t> collapsedShape,
516                                      PatternRewriter &rewriter) {
517   if (reassociationMaps.empty())
518     return failure();
519   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
520 
521   SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
522   originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
523 
524   reassociation.clear();
525   expandedShapeMap.clear();
526   // Compute the number of dimensions in the expanded op that correspond to each
527   // dimension of the original op.
528   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
529   expandedShapeMap.resize(fusedIndexMap.getNumDims());
530   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
531     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
532     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
533     numExpandedDims[pos] = foldedDims.getNumResults();
534     ArrayRef<int64_t> shape =
535         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
536     expandedShapeMap[pos].assign(shape.begin(), shape.end());
537   }
538   // The remaining dimensions remain the same.
539   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
540     if (expandedShapeMap[i].empty())
541       expandedShapeMap[i] = {originalLoopExtent[i]};
542 
543   // Compute reassociation map from the original op to the expanded op.
544   unsigned sum = 0;
545   reassociation.reserve(fusedIndexMap.getNumDims());
546   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
547     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
548     reassociation.emplace_back(seq.begin(), seq.end());
549     sum += numFoldedDim.value();
550   }
551   expandedOpNumDims = sum;
552   return success();
553 }
554 
555 /// Expanding the body of a linalg operation requires adaptations of the
556 /// accessed loop indices. Specifically, accesses of indices in the original
557 /// operation need to be replaced with linearizations of indices in the expanded
558 /// op. That requires the shapes of the expanded dimensions to be static (at
559 /// least all but the most significant). For now check that these are all
560 /// statically sized. Note that this could be extended to handle the dynamic
561 /// case, but the implementation below uses `affine.apply` which seems to have
562 /// issues when the shapes are not static.
563 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
564                                            const ExpansionInfo &expansionInfo,
565                                            PatternRewriter &rewriter) {
566   if (!genericOp.hasIndexSemantics())
567     return success();
568   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
569     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
570     if (expandedShape.size() == 1)
571       continue;
572     for (int64_t shape : expandedShape.drop_front()) {
573       if (ShapedType::isDynamic(shape)) {
574         return rewriter.notifyMatchFailure(
575             genericOp, "cannot expand due to index semantics and dynamic dims");
576       }
577     }
578   }
579   return success();
580 }
581 
582 /// Return the indexing map to use in the expanded op given the `indexingMap`
583 /// of the original operation.
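///
/// For example (hypothetical expansion): if the expansion maps d0 -> {e0, e1}
/// and d1 -> {e2}, the original indexing map (d0, d1) -> (d1, d0) becomes
/// (e0, e1, e2) -> (e2, e0, e1) in the expanded op.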
584 static AffineMap
585 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
586                            const ExpansionInfo &expansionInfo) {
587   SmallVector<AffineExpr> newExprs;
588   for (AffineExpr expr : indexingMap.getResults()) {
589     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
590     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
591         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
592           return builder.getAffineDimExpr(static_cast<unsigned>(v));
593         }));
594     newExprs.append(expandedExprs.begin(), expandedExprs.end());
595   }
596   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
597                         indexingMap.getNumSymbols(), newExprs,
598                         builder.getContext());
599 }
600 
601 /// Return the type of the operand/result to use in the expanded op given the
602 /// type in the original op.
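///
/// For example (hypothetical shapes): with indexing map (d0, d1) -> (d0, d1)
/// and expanded shapes d0 -> {2, 3} and d1 -> {4}, a tensor<?x4xf32> operand
/// becomes tensor<2x3x4xf32> in the expanded op.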
603 static RankedTensorType getExpandedType(RankedTensorType originalType,
604                                         AffineMap indexingMap,
605                                         const ExpansionInfo &expansionInfo) {
606   SmallVector<int64_t> expandedShape;
607   for (AffineExpr expr : indexingMap.getResults()) {
608     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
609     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
610     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
611   }
612   return RankedTensorType::get(expandedShape, originalType.getElementType());
613 }
614 
615 /// Returns the reassociation maps to use in the `tensor.expand_shape`
616 /// operation to convert the operands of the original operation to operands of
617 /// the expanded operation. The same method is used to compute the
618 /// `tensor.collapse_shape` used to collapse the result of the expanded
619 /// op to get the value that can replace all uses of the results of the original
620 /// op.
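///
/// For example (hypothetical expansion): with indexing map (d0, d1) -> (d1, d0)
/// and expansion d0 -> {e0, e1}, d1 -> {e2}, the reassociation for the operand
/// is [[0], [1, 2]]: its first dimension (indexed by d1) stays a single
/// dimension, while its second dimension (indexed by d0) is expanded into two.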
621 static SmallVector<ReassociationIndices>
622 getReassociationForExpansion(AffineMap indexingMap,
623                              const ExpansionInfo &expansionInfo) {
624   SmallVector<ReassociationIndices> reassociation;
625   unsigned numReshapeDims = 0;
626   for (AffineExpr expr : indexingMap.getResults()) {
627     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
628     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
629     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
630         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
631     reassociation.emplace_back(std::move(indices));
632     numReshapeDims += numExpandedDims;
633   }
634   return reassociation;
635 }
636 
637 /// Update the body of an expanded linalg operation having index semantics. The
638 /// indices of the original operation need to be recovered by linearizing the
639 /// indices of the corresponding dimensions of the expanded operation. For now it
640 /// is assumed that the shapes of the expanded operation needed for
641 /// linearization are static.
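///
/// As an illustrative sketch (extents and value names are hypothetical): if an
/// original loop dimension d is expanded into dimensions (e1, e2) where e2 has
/// static extent 4, then `linalg.index d` in the original body is rewritten as
/// `affine.apply affine_map<(d0, d1) -> (d0 + d1 * 4)>(%e2_idx, %e1_idx)`,
/// i.e. the inner index plus the outer index times the inner extent.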
642 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
643                                           Location loc, Region &fusedRegion,
644                                           const ExpansionInfo &expansionInfo) {
645   // Replace the original indices by the linearization of the expanded indices.
646   for (IndexOp indexOp :
647        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
648     ArrayRef<int64_t> expandedDims =
649         expansionInfo.getExpandedDims(indexOp.dim());
650     assert(!expandedDims.empty() && "expected valid expansion info");
651 
652     // Skip index operations that are not affected by the expansion.
653     if (expandedDims.size() == 1 &&
654         expandedDims.front() == (int64_t)indexOp.dim())
655       continue;
656 
657     // Linearize the expanded indices of the original index dimension.
658     OpBuilder::InsertionGuard guard(rewriter);
659     rewriter.setInsertionPointAfter(indexOp);
660     ArrayRef<int64_t> expandedDimsShape =
661         expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
662     SmallVector<Value> expandedIndices;
663     expandedIndices.reserve(expandedDims.size() - 1);
664     llvm::transform(
665         expandedDims.drop_front(), std::back_inserter(expandedIndices),
666         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
667     Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
668     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
669       assert(!ShapedType::isDynamic(std::get<0>(it)));
670       AffineExpr idx, acc;
671       bindDims(rewriter.getContext(), idx, acc);
672       newIndex = rewriter.create<AffineApplyOp>(
673           indexOp.getLoc(), idx + acc * std::get<0>(it),
674           ValueRange{std::get<1>(it), newIndex});
675     }
676     rewriter.replaceOp(indexOp, newIndex);
677   }
678 }
679 
680 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
681 /// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
682 /// Assumes that those conditions have been satisfied.
683 static Optional<SmallVector<Value>>
684 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
685                            OpOperand *fusableOpOperand,
686                            PatternRewriter &rewriter) {
687   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
688          "preconditions for fuse operation failed");
689   // Check if reshape is expanding or collapsing.
690   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
691   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
692   bool isExpanding = (expandingReshapeOp != nullptr);
693   RankedTensorType expandedType = isExpanding
694                                       ? expandingReshapeOp.getResultType()
695                                       : collapsingReshapeOp.getSrcType();
696   RankedTensorType collapsedType = isExpanding
697                                        ? expandingReshapeOp.getSrcType()
698                                        : collapsingReshapeOp.getResultType();
699 
700   ExpansionInfo expansionInfo;
701   if (failed(expansionInfo.compute(
702           genericOp, fusableOpOperand,
703           isExpanding ? expandingReshapeOp.getReassociationMaps()
704                       : collapsingReshapeOp.getReassociationMaps(),
705           expandedType.getShape(), collapsedType.getShape(), rewriter)))
706     return llvm::None;
707 
708   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
709     return llvm::None;
710 
711   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
712       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
713         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
714       }));
715 
716   SmallVector<Value> expandedOpOperands;
717   expandedOpOperands.reserve(genericOp.getNumInputs());
718   for (OpOperand *opOperand : genericOp.getInputOperands()) {
719     if (opOperand == fusableOpOperand) {
720       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
721                                                : collapsingReshapeOp.src());
722       continue;
723     }
724     if (genericOp.isInputTensor(opOperand)) {
725       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
726       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
727       RankedTensorType expandedOperandType =
728           getExpandedType(opOperandType, indexingMap, expansionInfo);
729       if (expandedOperandType != opOperand->get().getType()) {
730         // Reshape the operand to get the right type.
731         SmallVector<ReassociationIndices> reassociation =
732             getReassociationForExpansion(indexingMap, expansionInfo);
733         if (failed(reshapeLikeShapesAreCompatible(
734                 [&](const Twine &msg) {
735                   return rewriter.notifyMatchFailure(genericOp, msg);
736                 },
737                 opOperandType.getShape(), expandedOperandType.getShape(),
738                 reassociation,
739                 /*isExpandingReshape=*/true)))
740           return llvm::None;
741         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
742             genericOp.getLoc(), expandedOperandType, opOperand->get(),
743             reassociation));
744         continue;
745       }
746     }
747     expandedOpOperands.push_back(opOperand->get());
748   }
749 
750   Location loc = genericOp.getLoc();
751   SmallVector<Value> outputs;
752   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
753     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
754     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
755     RankedTensorType expandedOutputType =
756         getExpandedType(opOperandType, indexingMap, expansionInfo);
757     if (expandedOutputType != opOperand->get().getType()) {
758       SmallVector<ReassociationIndices> reassociation =
759           getReassociationForExpansion(indexingMap, expansionInfo);
760       if (failed(reshapeLikeShapesAreCompatible(
761               [&](const Twine &msg) {
762                 return rewriter.notifyMatchFailure(genericOp, msg);
763               },
764               opOperandType.getShape(), expandedOutputType.getShape(),
765               reassociation,
766               /*isExpandingReshape=*/true)))
767         return llvm::None;
768       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
769           genericOp.getLoc(), expandedOutputType, opOperand->get(),
770           reassociation));
771     }
772   }
773 
774   // The iterator types of the expanded op are all parallel.
775   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
776                                        getParallelIteratorTypeName());
777 
778   TypeRange resultTypes = ValueRange(outputs).getTypes();
779   auto fusedOp =
780       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
781                                  /*inputs=*/expandedOpOperands, outputs,
782                                  expandedOpIndexingMaps, iteratorTypes);
783   Region &fusedRegion = fusedOp->getRegion(0);
784   Region &originalRegion = genericOp->getRegion(0);
785   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
786 
787   // Update the index accesses after the expansion.
788   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
789 
790   // Reshape the result values to their original shape if this is a collapsing
791   // reshape folded into its consumer.
792   SmallVector<Value> resultVals;
793   for (OpResult opResult : genericOp->getOpResults()) {
794     int64_t resultNumber = opResult.getResultNumber();
795     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
796       SmallVector<ReassociationIndices> reassociation =
797           getReassociationForExpansion(
798               genericOp.getTiedIndexingMap(
799                   genericOp.getOutputOperand(resultNumber)),
800               expansionInfo);
801       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
802           genericOp.getLoc(), opResult.getType(),
803           fusedOp->getResult(resultNumber), reassociation));
804     } else {
805       resultVals.push_back(fusedOp->getResult(resultNumber));
806     }
807   }
808   // Assuming a single result.
809   return resultVals;
810 }
811 
812 namespace {
813 
814 /// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
815 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
816 /// in the consumer is expanded.
817 class FoldWithProducerReshapeOpByExpansion
818     : public OpRewritePattern<GenericOp> {
819 public:
820   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
821                                        ControlFusionFn foldReshapes,
822                                        PatternBenefit benefit = 1)
823       : OpRewritePattern<GenericOp>(context, benefit),
824         controlFoldingReshapes(std::move(foldReshapes)) {}
825 
826   LogicalResult matchAndRewrite(GenericOp genericOp,
827                                 PatternRewriter &rewriter) const override {
828     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
829       tensor::CollapseShapeOp reshapeOp =
830           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
831       if (!reshapeOp)
832         continue;
833       // Fold only if
834       // - The tensor reshape op is folding.
835       // - All constraints of fusing with reshape by expansion are met.
836       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
837           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
838         continue;
839 
840       Optional<SmallVector<Value>> replacementValues =
841           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
842       if (!replacementValues)
843         return failure();
844       rewriter.replaceOp(genericOp, replacementValues.getValue());
845       return success();
846     }
847     return failure();
848   }
849 
850 private:
851   ControlFusionFn controlFoldingReshapes;
852 };
853 
854 /// Pattern to fold a tensor.expand_shape op with its producer generic op
855 /// by expanding the dimensionality of the loop in the producer op.
856 struct FoldReshapeWithGenericOpByExpansion
857     : public OpRewritePattern<tensor::ExpandShapeOp> {
858 
859   FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
860                                       ControlFusionFn foldReshapes,
861                                       PatternBenefit benefit = 1)
862       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
863         controlFoldingReshapes(std::move(foldReshapes)) {}
864 
865   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
866                                 PatternRewriter &rewriter) const override {
867     // Fold only if all constraints of fusing with reshape by expansion are met.
868     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
869     if (!producer || producer.getNumOutputs() != 1 ||
870         !isFusableWithReshapeByDimExpansion(producer,
871                                             producer.getOutputOperand(0)) ||
872         !controlFoldingReshapes(producer->getResult(0),
873                                 reshapeOp->getOpOperand(0)))
874       return failure();
875     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
876         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
877     if (!replacementValues)
878       return failure();
879     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
880     return success();
881   }
882 
883 private:
884   ControlFusionFn controlFoldingReshapes;
885 };
886 } // namespace
887 
888 //===---------------------------------------------------------------------===//
889 // Methods and patterns to fuse reshape with linalg.generic operations by
890 // contraction of dimensions.
891 //===---------------------------------------------------------------------===//
892 
893 /// For a given list of indices in the range of the `indexingMap` that are
894 /// folded, return the indices of the corresponding domain. Since `indexingMap`
895 /// is a projected permutation, all the elements of the returned reassociation
896 /// are distinct.
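///
/// For example (hypothetical map): with indexing map (d0, d1, d2) -> (d2, d0, d1)
/// and range reassociation [1, 2] (positions into the map results), the
/// returned domain reassociation is [0, 1].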
897 static ReassociationIndices
898 getDomainReassociation(AffineMap indexingMap,
899                        ReassociationIndicesRef rangeReassociation) {
900   assert(indexingMap.isProjectedPermutation() &&
901          "expected projected permutation");
902 
903   ReassociationIndices domainReassociation = llvm::to_vector<4>(
904       llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
905         return indexingMap.getResults()[pos]
906             .cast<AffineDimExpr>()
907             .getPosition();
908       }));
909   // The projected permutation semantics ensures that there is no repetition of
910   // the domain indices.
911   return domainReassociation;
912 }
913 
914 /// For a given `dimSequence`, check if the sequence is conserved in the
915 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
916 /// If the sequence does not appear in the map at all, return true as well.
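///
/// For example (hypothetical maps): the sequence [1, 2] is preserved in
/// (d0, d1, d2) -> (d0, d1, d2), and trivially in (d0, d1, d2, d3) -> (d0, d3)
/// where it does not appear, but it is not preserved in
/// (d0, d1, d2) -> (d0, d2, d1).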
917 static bool isDimSequencePreserved(AffineMap indexingMap,
918                                    ReassociationIndicesRef dimSequence) {
919   assert(!dimSequence.empty() &&
920          "expected non-empty list for dimension sequence");
921   assert(indexingMap.isProjectedPermutation() &&
922          "expected indexing map to be projected permutation");
923 
924   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
925   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
926 
927   unsigned dimSequenceStart = dimSequence[0];
928   for (const auto &expr : enumerate(indexingMap.getResults())) {
929     unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
930     // 1. Check if this is the start of the sequence.
931     if (dimInMapStart == dimSequenceStart) {
932       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
933         return false;
934       // 1a. Check if sequence is preserved.
935       for (const auto &dimInSequence : enumerate(dimSequence)) {
936         unsigned dimInMap =
937             indexingMap.getResult(expr.index() + dimInSequence.index())
938                 .cast<AffineDimExpr>()
939                 .getPosition();
940         if (dimInMap != dimInSequence.value())
941           return false;
942       }
943       // Found the sequence. Projected permutation
944       // enforces that all AffineDimExprs in the result are unique, so no
945       // further checks are needed.
946       return true;
947     }
948     // 2. If the position in this expr (an AffineDimExpr) is part of the
949     // sequence but did not start it, return false: the entire sequence cannot
950     // exist in the indexing map.
951     if (sequenceElements.count(dimInMapStart))
952       return false;
953   }
954   // 3. No element of sequence found. Return true.
955   return true;
956 }
957 
958 // Return the list of dimensions of the iteration domain that can be
959 // collapsed to allow for fusion with a producer that is an expand_shape
960 // operation. If all dimensions created by expansion can be collapsed in the
961 // iteration space then the reshape is defunct.
962 //
963 // Example:
964 //
965 // ```mlir
966 // #map = affine_map<(d0, d1) -> (d0, d1)>
967 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
968 // %2 = linalg.init_tensor [..] : tensor<?x4xf32>
969 // %3 = linalg.generic {
970 //     indexing_maps = [#map, #map],
971 //     iterator_types = ["parallel" ,"parallel"]}
972 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
973 // ```
974 //
975 // can be fused by collapsing the dimensions of the iteration space.
976 //
977 // ```mlir
978 // #map = affine_map<(d0) -> (d0)>
979 // %2 = linalg.init_tensor [..] : tensor<?xf32>
980 // %3 = linalg.generic {
981 //     indexing_maps = [#map, #map],
982 //     iterator_types = ["parallel"]}
983 //     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
984 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
985 // ```
986 //
987 // In the following example,
988 //
989 // ```mlir
990 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
991 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
992 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
993 // %2 = linalg.init_tensor [..] : tensor<4x?xf32>
994 // %2 = linalg.generic {
995 //     indexing_maps = [#map0, #map1],
996 //     iterator_types = ["parallel" ,"parallel"]}
997 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
998 // ```
999 //
1000 // the reshape cannot be fused with the generic op by collapsing the op
1001 // dimensions since the indexing maps will have to contain mods and divs
1002 // to preserve the access pattern. When no dimensions of the iteration
1003 // space are collapsible, an empty vector is returned.
1004 static SmallVector<ReassociationIndices>
1005 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1006                                  ArrayRef<ReassociationIndices> reassociation) {
1007   // Some basic checks for this fusion to be valid.
1008   if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1009     return {};
1010 
1011   if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
1012         return map.isProjectedPermutation();
1013       })) {
1014     return {};
1015   }
1016 
1017   // Compute all the loops with the reduction iterator types.
1018   SmallVector<int64_t> reductionDims;
1019   for (const auto &iteratorType : llvm::enumerate(genericOp.iterator_types())) {
1020     if (isReductionIterator(iteratorType.value())) {
1021       reductionDims.push_back(iteratorType.index());
1022     }
1023   }
1024 
1025   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1026   AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
1027   auto iteratorTypes = genericOp.iterator_types().getValue();
1028   SmallVector<ReassociationIndices> iterationSpaceReassociation;
1029   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1030     assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1031 
1032     // Ignore dims that are not folded.
1033     if (foldedRangeDims.size() == 1)
1034       continue;
1035 
1036     ReassociationIndices foldedIterationSpaceDims =
1037         getDomainReassociation(indexingMap, foldedRangeDims);
1038 
1039     // Check that the folded iteration dims do not contain already processed
1040     // dims.
1041     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1042           return processedIterationDims.count(dim);
1043         }))
1044       continue;
1045 
1046     // Check that all folded iterator types are all parallel or all reductions.
1047     Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
1048     if (!isParallelIterator(startIteratorType) &&
1049         !isReductionIterator(startIteratorType))
1050       continue;
1051     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1052           return iteratorTypes[dim] != startIteratorType;
1053         }))
1054       continue;
1055 
1056     // If the folded dimensions correspond to a "reduction" iterator type,
1057     // the folded dimensions need to be "in-order". Strictly speaking this is
1058     // not necessary for reductions that are associative and commutative, but
1059     // a stricter definition of reduction is used for now.
1060     if (isReductionIterator(startIteratorType)) {
1061       bool isContiguous = false;
1062       for (const auto &startDim : llvm::enumerate(reductionDims)) {
1063         // Move window in `reductionDims` to start of the folded iteration dims.
1064         if (startDim.value() != foldedIterationSpaceDims[0])
1065           continue;
1066         // If the sizes don't match, it is trivially not contiguous. This
1067         // condition should not be hit.
1068         if (startDim.index() + foldedIterationSpaceDims.size() >
1069             reductionDims.size())
1070           break;
1071         // Check that the contiguity is maintained.
1072         isContiguous = true;
1073         for (const auto &foldedDim :
1074              llvm::enumerate(foldedIterationSpaceDims)) {
1075           if (reductionDims[foldedDim.index() + startDim.index()] !=
1076               foldedDim.value()) {
1077             isContiguous = false;
1078             break;
1079           }
1080         }
1081         break;
1082       }
1083       if (!isContiguous)
1084         continue;
1085     }
1086 
1087     // Check that the sequence is preserved in all indexing maps.
1088     if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
1089           return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims);
1090         }))
1091       continue;
1092 
1093     processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1094                                   foldedIterationSpaceDims.end());
1095     iterationSpaceReassociation.emplace_back(
1096         std::move(foldedIterationSpaceDims));
1097   }
1098 
1099   return iterationSpaceReassociation;
1100 }
1101 
1102 /// Helper class to carry state while collapsing the `linalg.generic` op.
1103 namespace {
1104 class CollapsingInfo {
1105 public:
1106   LogicalResult initialize(unsigned origNumLoops,
1107                            ArrayRef<ReassociationIndices> foldedIterationDims) {
1108     llvm::SmallDenseSet<int64_t, 4> processedDims;
1109     // Find all the dims that are folded.
1110     for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1111       if (foldedIterationDim.empty())
1112         continue;
1113       // If the folded dims contain dims already folded, that's an illegal
1114       // specification. Repetition within a list is also illegal.
1115       for (auto dim : foldedIterationDim) {
1116         if (dim >= origNumLoops)
1117           return failure();
1118         if (processedDims.count(dim))
1119           return failure();
1120         processedDims.insert(dim);
1121       }
1122       collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1123                                                    foldedIterationDim.end());
1124     }
1125     if (processedDims.size() > origNumLoops)
1126       return failure();
1127 
1128     // Add all the preserved dims of the original op as single
1129     // elements to `collapsedOpToOrigOpIterationDim`.
1130     for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1131       if (processedDims.count(dim))
1132         continue;
1133       collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1134     }
1135 
1136     llvm::sort(collapsedOpToOrigOpIterationDim,
1137                [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1138                  return lhs[0] < rhs[0];
1139                });
1140     origOpToCollapsedOpIterationDim.resize(origNumLoops);
1141     for (const auto &foldedDims :
1142          llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1143       for (const auto &dim : enumerate(foldedDims.value()))
1144         origOpToCollapsedOpIterationDim[dim.value()] =
1145             std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1146     }
1147     return success();
1148   }
1149 
1150   /// Return mapping from collapsed loop domain to original loop domain.
1151   ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1152     return collapsedOpToOrigOpIterationDim;
1153   }
1154 
1155   /// Return mapping from original loop domain to collapsed loop domain. The
1156   /// mapping is a pair. First value is the dimension in the collapsed loop that
1157   /// the original loop is mapped to. Second is the relative position in the
1158   /// folded list of this domain. For example, if the original loop domain is 5D
1159   /// and the collapsed loop domain folds it into two dimensions, i.e.
1160   ///
1161   /// ```
1162   /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1163   /// ```
1164   ///
1165   /// then
1166   ///
1167   /// ```
1168   ///  origOpToCollapsedOpMapping[0] = {0, 0};
1169   ///  origOpToCollapsedOpMapping[1] = {0, 1};
1170   ///  origOpToCollapsedOpMapping[2] = {0, 2};
1171   ///  origOpToCollapsedOpMapping[3] = {1, 0};
1172   ///  origOpToCollapsedOpMapping[4] = {1, 1};
1173   /// ```
1174   ///
1175   ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1176     return origOpToCollapsedOpIterationDim;
1177   }
1178 
1179   /// Return the collapsed op iteration domain rank.
1180   unsigned getCollapsedOpIterationRank() const {
1181     return collapsedOpToOrigOpIterationDim.size();
1182   }
1183 
1184 private:
1185   /// Map from the iteration domain index in collapsed op to the iteration
1186   /// domain indices in the original op.
1187   SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1188 
1189   /// Map from iteration domain index in the original op to the iteration domain
1190   /// index in the collapsed op.
1191   SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1192 };
1193 } // namespace
1194 
1195 /// Get the iterator types for the collapsed operation given the original
1196 /// iterator types and collapsed dimensions.
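///
/// For example (illustrative values only): with original iterator types
/// `["parallel", "parallel", "reduction", "parallel"]` and folded iteration
/// dimensions `[[0, 1]]`, the collapsed iterator types are
/// `["parallel", "reduction", "parallel"]`; the pre-condition checks ensure
/// that all dimensions within a folded group share the same iterator type.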
1197 static SmallVector<StringRef>
1198 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
1199                             const CollapsingInfo &collapsingInfo) {
1200   SmallVector<StringRef> collapsedIteratorTypes;
1201   for (ReassociationIndicesRef foldedIterDims :
1202        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1203     assert(!foldedIterDims.empty() &&
1204            "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have verified that the iterator types of all
    // folded dimensions are the same.
1208     collapsedIteratorTypes.push_back(
1209         iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1210   }
1211   return collapsedIteratorTypes;
1212 }
1213 
1214 /// Compute the indexing map in the collapsed op that corresponds to the given
1215 /// `indexingMap` of the original operation.
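///
/// For example (illustrative values only): for an original op with 4 loops and
/// folded iteration dimensions `[[0, 1]]`, the collapsed loops are
/// `{0, 1}, {2}, {3}`. The original indexing map
/// `affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>` then collapses to
/// `affine_map<(d0, d1, d2) -> (d0, d2)>`: `d1` is dropped because it is not
/// the first dimension of its folded group, and the remaining dimensions are
/// renumbered into the collapsed loop space.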
1216 static AffineMap
1217 getCollapsedOpIndexingMap(AffineMap indexingMap,
1218                           const CollapsingInfo &collapsingInfo) {
1219   MLIRContext *context = indexingMap.getContext();
1220   assert(indexingMap.isProjectedPermutation() &&
1221          "expected indexing map to be projected permutation");
1222   SmallVector<AffineExpr> resultExprs;
1223   auto origOpToCollapsedOpMapping =
1224       collapsingInfo.getOrigOpToCollapsedOpMapping();
1225   for (auto expr : indexingMap.getResults()) {
1226     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    // If the dim is not the first dim of its collapsed group, do nothing.
1228     if (origOpToCollapsedOpMapping[dim].second != 0)
1229       continue;
1230     // The next n-dims are guaranteed to be collapsed. So just use the
1231     // iteration dimension of the collapsed op.
1232     resultExprs.push_back(
1233         getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1234   }
1235   return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1236                         resultExprs, context);
1237 }
1238 
1239 /// Return the `reassociation` indices to use to collapse the operand when the
1240 /// iteration space of a generic op is collapsed.
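///
/// For example (illustrative values only): with folded iteration dimensions
/// `[[0, 1]]` and an operand indexing map
/// `affine_map<(d0, d1, d2, d3) -> (d2, d0, d1)>`, the operand reassociation
/// is `[[0], [1, 2]]`: result 0 maps to an un-collapsed dimension, while
/// results 1 and 2 cover the folded group `{d0, d1}` and are collapsed into a
/// single dimension of the operand.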
1241 static SmallVector<ReassociationIndices>
1242 getOperandReassociation(AffineMap indexingMap,
1243                         const CollapsingInfo &collapsingInfo) {
1244   unsigned counter = 0;
1245   SmallVector<ReassociationIndices> operandReassociation;
1246   auto origOpToCollapsedOpMapping =
1247       collapsingInfo.getOrigOpToCollapsedOpMapping();
1248   auto collapsedOpToOrigOpMapping =
1249       collapsingInfo.getCollapsedOpToOrigOpMapping();
1250   while (counter < indexingMap.getNumResults()) {
1251     unsigned dim =
1252         indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
1253     if (origOpToCollapsedOpMapping[dim].second == 0) {
      // This is the start of a group of collapsed dimensions of the iteration
      // space that is guaranteed to be preserved in the indexing map. The
      // number of folded dims is obtained from the collapsed op to original op
      // mapping.
1257       unsigned numFoldedDims =
1258           collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1259               .size();
1260       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1261       operandReassociation.emplace_back(range.begin(), range.end());
1262       counter += numFoldedDims;
1263     }
1264   }
1265   return operandReassociation;
1266 }
1267 
1268 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
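///
/// For example (a sketch with assumed types): for an operand of type
/// `tensor<4x8x16xf32>` whose operand reassociation is `[[0, 1], [2]]`, a
/// `tensor.collapse_shape` producing a `tensor<32x16xf32>` value is created;
/// if none of the operand's dimensions are collapsed, the original value is
/// returned unchanged.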
1269 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1270                                    OpOperand *opOperand,
1271                                    const CollapsingInfo &collapsingInfo,
1272                                    OpBuilder &builder) {
1273   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1274   SmallVector<ReassociationIndices> operandReassociation =
1275       getOperandReassociation(indexingMap, collapsingInfo);
1276 
  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, there is nothing to do for
  // this operand.
1279   Value operand = opOperand->get();
1280   if (operandReassociation.size() == indexingMap.getNumResults())
1281     return operand;
1282 
1283   // Insert a reshape to collapse the dimensions.
1284   auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1285       loc, operand, operandReassociation);
1286   return reshapeOp.getResult();
1287 }
1288 
/// Rewrite the `linalg.index` operations in the original generic op to compute
/// their values from the induction variables of the collapsed operation.
static void
generateCollapsedIndexingRegion(Location loc, Block *block,
                                const CollapsingInfo &collapsingInfo,
                                ValueRange loopRange,
                                PatternRewriter &rewriter) {
1295   OpBuilder::InsertionGuard g(rewriter);
1296   rewriter.setInsertionPointToStart(block);
1297 
1298   // Collect all the original index ops.
1299   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1300 
  // For each folded dimension list, resolve the original induction variable
  // values in terms of the folded dimension's induction variable. Given
  //   i_{folded} = (i_0 * d_1 + i_1) * d_2 + i_2
  // the original induction variables are recovered as
  //   i_2 = i_{folded} % d_2
  //   i_1 = (i_{folded} / d_2) % d_1
  //   i_0 = i_{folded} / (d_1 * d_2)
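  //
  // For example (a sketch with illustrative SSA names, not necessarily the
  // verbatim output), folding dimensions [0, 1, 2] with loop bounds %d1 and
  // %d2 produces IR of the form
  //
  //   %iv = linalg.index 0 : index
  //   %i2 = arith.remui %iv, %d2 : index
  //   %t0 = arith.divui %iv, %d2 : index
  //   %i1 = arith.remui %t0, %d1 : index
  //   %i0 = arith.divui %t0, %d1 : index
  //
  // and the original linalg.index results are replaced by %i0, %i1 and %i2.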
1308   llvm::DenseMap<unsigned, Value> indexReplacementVals;
1309   for (auto &foldedDims :
1310        enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1311     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1312     Value newIndexVal =
1313         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1314     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1315       indexReplacementVals[dim] =
1316           rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1317       newIndexVal =
1318           rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1319     }
1320     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1321   }
1322 
1323   for (auto indexOp : indexOps) {
1324     auto dim = indexOp.dim();
1325     rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1326   }
1327 }
1328 
/// Implementation of fusion with a reshape operation by collapsing the
/// dimensions of the iteration space of the generic op.
1330 static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
1331     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1332     OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
1333   // Bail on trivial no-op cases.
1334   if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1335       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1336         return foldedDims.size() <= 1;
1337       }))
1338     return failure();
1339 
1340   CollapsingInfo collapsingInfo;
1341   if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1342                                        foldedIterationDims))) {
1343     return rewriter.notifyMatchFailure(
1344         genericOp, "illegal to collapse specified dimensions");
1345   }
1346 
  // Get the iterator types for the collapsed operation.
1348   SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
1349       genericOp.iterator_types().getValue(), collapsingInfo);
1350 
1351   // Get the indexing maps.
1352   auto indexingMaps = llvm::to_vector(
1353       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
1354         return getCollapsedOpIndexingMap(map, collapsingInfo);
1355       }));
1356 
1357   Location loc = genericOp->getLoc();
1358 
1359   // Get the input operands.
1360   auto inputOperands = llvm::to_vector(
1361       llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
1362         return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1363                                      rewriter);
1364       }));
1365 
1366   // Get the output operands and result types.
1367   SmallVector<Type> resultTypes;
1368   SmallVector<Value> outputOperands;
1369   resultTypes.reserve(genericOp.getNumOutputs());
1370   outputOperands.reserve(genericOp.getNumOutputs());
1371   for (OpOperand *output : genericOp.getOutputOperands()) {
1372     Value newOutput =
1373         getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
1374     outputOperands.push_back(newOutput);
1375     resultTypes.push_back(newOutput.getType());
1376   }
1377 
1378   // Create the generic op.
1379   auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
1380       loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1381       iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1382   Block *origOpBlock = &genericOp->getRegion(0).front();
1383   Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
1384   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1385                        collapsedOpBlock->getArguments());
1386 
1387   if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop ranges of the original generic op.
1389     OpBuilder::InsertionGuard g(rewriter);
1390     rewriter.setInsertionPoint(collapsedGenericOp);
1391     SmallVector<Range> loopRanges =
1392         cast<LinalgOp>(genericOp.getOperation())
1393             .createLoopRanges(rewriter, genericOp.getLoc());
1394     assert(llvm::all_of(loopRanges,
1395                         [](Range range) {
1396                           return matchPattern(range.offset, m_Zero()) &&
1397                                  matchPattern(range.stride, m_One());
1398                         }) &&
1399            "expected all loop ranges to have zero start and unit stride");
1400     SmallVector<Value> loopBound = llvm::to_vector(
1401         llvm::map_range(loopRanges, [](Range range) { return range.size; }));
1402     generateCollapsedIndexingRegion(loc,
1403                                     &collapsedGenericOp->getRegion(0).front(),
1404                                     collapsingInfo, loopBound, rewriter);
1405   }
1406 
  // Insert an expanding reshape on each result to recover the original result
  // type.
1409   SmallVector<Value> results;
1410   for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
1411     Value collapsedOpResult =
1412         collapsedGenericOp->getResult(originalResult.index());
1413     auto originalResultType =
1414         originalResult.value().getType().cast<ShapedType>();
1415     auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
1416     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1417       AffineMap indexingMap =
1418           genericOp.getTiedIndexingMapForResult(originalResult.value());
1419       SmallVector<ReassociationIndices> reassociation =
1420           getOperandReassociation(indexingMap, collapsingInfo);
1421       Value result = rewriter.create<tensor::ExpandShapeOp>(
1422           loc, originalResultType, collapsedOpResult, reassociation);
1423       results.push_back(result);
1424     } else {
1425       results.push_back(collapsedOpResult);
1426     }
1427   }
1428   return results;
1429 }
1430 
1431 namespace {
1432 
/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing dimensions of the generic op's iteration space.
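///
/// For example (a sketch; `%arg0`, `%init`, the maps and the payload are
/// illustrative):
///
/// ```
///   %0 = tensor.expand_shape %arg0 [[0, 1]]
///       : tensor<?xf32> into tensor<?x4xf32>
///   %1 = linalg.generic {
///          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                           affine_map<(d0, d1) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel"]}
///       ins(%0 : tensor<?x4xf32>) outs(%init : tensor<?x4xf32>) {...}
/// ```
///
/// is rewritten so that the generic op runs on a collapsed (1-D) iteration
/// space, with a `tensor.collapse_shape` of `%0` as its input (which folds
/// away against the `tensor.expand_shape`) and a `tensor.expand_shape` on its
/// result to recover the original result type.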
1435 class FoldWithProducerReshapeOpByCollapsing
1436     : public OpRewritePattern<GenericOp> {
1437 public:
1438   FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1439                                         ControlFusionFn foldReshapes,
1440                                         PatternBenefit benefit = 1)
1441       : OpRewritePattern<GenericOp>(context, benefit),
1442         controlFoldingReshapes(std::move(foldReshapes)) {}
1443 
1444   LogicalResult matchAndRewrite(GenericOp genericOp,
1445                                 PatternRewriter &rewriter) const override {
1446     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1447       tensor::ExpandShapeOp reshapeOp =
1448           opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
1449       if (!reshapeOp)
1450         continue;
1451 
1452       SmallVector<ReassociationIndices> collapsableIterationDims =
1453           getCollapsableIterationSpaceDims(genericOp, opOperand,
1454                                            reshapeOp.getReassociationIndices());
1455       if (collapsableIterationDims.empty() ||
1456           !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
1457         continue;
1458       }
1459 
1460       Optional<SmallVector<Value>> replacements =
1461           collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1462                                          opOperand, rewriter);
1463       if (!replacements) {
1464         return rewriter.notifyMatchFailure(
1465             genericOp, "failed to do the fusion by collapsing transformation");
1466       }
1467 
1468       rewriter.replaceOp(genericOp, replacements.getValue());
1469       return success();
1470     }
1471     return failure();
1472   }
1473 
1474 private:
1475   ControlFusionFn controlFoldingReshapes;
1476 };
1477 } // namespace
1478 
1479 //===---------------------------------------------------------------------===//
1480 // Methods and patterns that fuse constants with linalg.generic operations.
1481 //===---------------------------------------------------------------------===//
1482 
1483 namespace {
1484 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1485 /// handle cases where the constant is not single-valued.
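///
/// For example (a sketch of the intent): an input operand produced by
/// `arith.constant dense<42.0> : tensor<4x8xf32>` is dropped from the generic
/// op, and inside the payload the corresponding block argument is replaced by
/// the scalar `arith.constant 42.0 : f32`.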
1486 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1487 public:
1488   FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1489       : OpRewritePattern<GenericOp>(context, benefit) {}
1490 
1491   LogicalResult matchAndRewrite(GenericOp genericOp,
1492                                 PatternRewriter &rewriter) const override {
1493     if (!genericOp.hasTensorSemantics())
1494       return failure();
1495     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1496       Operation *def = opOperand->get().getDefiningOp();
1497       Attribute constantAttr;
1498       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1499         {
1500           DenseElementsAttr splatAttr;
1501           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1502               splatAttr.isSplat() &&
1503               splatAttr.getType().getElementType().isIntOrFloat()) {
1504             constantAttr = splatAttr.getSplatValue<Attribute>();
1505             return true;
1506           }
1507         }
1508         {
1509           IntegerAttr intAttr;
1510           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1511             constantAttr = intAttr;
1512             return true;
1513           }
1514         }
1515         {
1516           FloatAttr floatAttr;
1517           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1518             constantAttr = floatAttr;
1519             return true;
1520           }
1521         }
1522         return false;
1523       };
1524 
1525       auto resultValue = opOperand->get().dyn_cast<OpResult>();
1526       if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1527         continue;
1528 
      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation, with the
      // entries for the constant operand dropped.
1532       SmallVector<AffineMap> fusedIndexMaps;
1533       SmallVector<Value> fusedOperands;
1534       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1535       fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1536       fusedOperands.reserve(genericOp.getNumInputs());
1537       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1538       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1539         if (inputOperand == opOperand)
1540           continue;
1541         Value inputValue = inputOperand->get();
1542         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1543         fusedOperands.push_back(inputValue);
1544         fusedLocs.push_back(inputValue.getLoc());
1545       }
1546       for (OpOperand *outputOperand : genericOp.getOutputOperands())
1547         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1548 
      // Check that the fused operation's shapes-to-loops map is computable.
1550       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1551         return rewriter.notifyMatchFailure(
1552             genericOp, "fused op loop bound computation failed");
1553       }
1554 
1555       // Create a constant scalar value from the splat constant.
1556       Value scalarConstant = rewriter.create<arith::ConstantOp>(
1557           def->getLoc(), constantAttr, constantAttr.getType());
1558 
1559       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1560       auto fusedOp = rewriter.create<GenericOp>(
1561           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1562           /*inputs=*/fusedOperands,
1563           /*outputs=*/outputOperands,
1564           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1565           genericOp.iterator_types(),
1566           /*doc=*/nullptr,
1567           /*library_call=*/nullptr);
1568 
      // Map the block argument corresponding to the replaced operand to the
      // scalar constant.
1571       Region &region = genericOp->getRegion(0);
1572       Block &entryBlock = *region.begin();
1573       BlockAndValueMapping mapping;
1574       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1575                   scalarConstant);
1576       Region &fusedRegion = fusedOp->getRegion(0);
1577       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1578                                  mapping);
1579       rewriter.replaceOp(genericOp, fusedOp->getResults());
1580       return success();
1581     }
1582     return failure();
1583   }
1584 };
1585 
1586 } // namespace
1587 
1588 //===---------------------------------------------------------------------===//
1589 // Miscellaneous patterns that help fusion.
1590 //===---------------------------------------------------------------------===//
1591 
1592 namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
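///
/// For example (a sketch; `%filled` is an illustrative value): if the payload
/// of a `linalg.generic` never reads the block argument tied to
/// `outs(%filled : tensor<?x?xf32>)`, the `%filled` operand is replaced by a
/// freshly created `linalg.init_tensor` of the same shape, breaking the
/// use-def dependence on `%filled` and enabling further fusion.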
1597 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1598   using OpRewritePattern<GenericOp>::OpRewritePattern;
1599 
1600   LogicalResult matchAndRewrite(GenericOp op,
1601                                 PatternRewriter &rewriter) const override {
1602     rewriter.startRootUpdate(op);
1603     bool modifiedOutput = false;
1604     Location loc = op.getLoc();
1605     for (OpOperand *opOperand : op.getOutputOperands()) {
1606       if (!op.payloadUsesValueFromOperand(opOperand)) {
1607         Value operandVal = opOperand->get();
1608         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1609         if (!operandType)
1610           continue;
1611 
1612         // If outs is sparse, leave it to the sparse compiler.
1613         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
1614           continue;
1615 
1616         // If outs is already an `init_tensor` operation, nothing to do.
1617         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1618         if (definingOp)
1619           continue;
1620         modifiedOutput = true;
1621         SmallVector<Value> dynamicDims;
1622         for (const auto &dim : llvm::enumerate(operandType.getShape())) {
1623           if (dim.value() != ShapedType::kDynamicSize)
1624             continue;
1625           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1626               loc, operandVal, dim.index()));
1627         }
1628         Value initTensor = rewriter.create<InitTensorOp>(
1629             loc, dynamicDims, operandType.getShape(),
1630             operandType.getElementType());
1631         op->setOperand(opOperand->getOperandNumber(), initTensor);
1632       }
1633     }
1634     if (!modifiedOutput) {
1635       rewriter.cancelRootUpdate(op);
1636       return failure();
1637     }
1638     rewriter.finalizeRootUpdate(op);
1639     return success();
1640   }
1641 };
1642 
/// Fold a `linalg.fill` into a consuming `linalg.generic` by forwarding the
/// fill value into the payload.
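///
/// For example (a sketch of the intent): if an input of the `linalg.generic`
/// is produced by a `linalg.fill` with fill value `%cst`, and the payload
/// reads the corresponding block argument, all uses of that block argument are
/// replaced with `%cst` so the fill value is used directly inside the region.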
1644 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
1645   using OpRewritePattern<GenericOp>::OpRewritePattern;
1646 
1647   LogicalResult matchAndRewrite(GenericOp genericOp,
1648                                 PatternRewriter &rewriter) const override {
1649     if (!genericOp.hasTensorSemantics())
1650       return failure();
1651     bool fillFound = false;
1652     Block &payload = genericOp.region().front();
1653     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1654       if (!genericOp.payloadUsesValueFromOperand(opOperand))
1655         continue;
1656       FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1657       if (!fillOp)
1658         continue;
1659       fillFound = true;
1660       payload.getArgument(opOperand->getOperandNumber())
1661           .replaceAllUsesWith(fillOp.value());
1662     }
1663     return success(fillFound);
1664   }
1665 };
1666 } // namespace
1667 
1668 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1669     RewritePatternSet &patterns,
1670     const ControlFusionFn &controlFoldingReshapes) {
1671   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1672                                                     controlFoldingReshapes);
1673   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1674                                                      controlFoldingReshapes);
1675 }
1676 
1677 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1678     RewritePatternSet &patterns,
1679     const ControlFusionFn &controlFoldingReshapes) {
1680   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
1681                                                       controlFoldingReshapes);
1682 }
1683 
1684 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1685     RewritePatternSet &patterns,
1686     const ControlFusionFn &controlElementwiseOpsFusion) {
1687   auto *context = patterns.getContext();
1688   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1689   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1690                RemoveOutsDependency>(context);
1691 }
1692 
1693 //===---------------------------------------------------------------------===//
1694 // Passes
1695 //===---------------------------------------------------------------------===//
1696 
1697 namespace {
1698 
1699 /// Pass that fuses generic ops on tensors. Used only for testing.
1700 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
1701 // patterns added here heavily depends on the cost function used. Having an
1702 // opinionated pass of this form is not recommended. Deprecate this pass in
1703 // favor of test passes that check the functionality of each of the patterns
1704 // added here individually.
1705 struct LinalgElementwiseOpFusionPass
1706     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1707   void runOnOperation() override {
1708     Operation *op = getOperation();
1709     MLIRContext *context = op->getContext();
1710     RewritePatternSet patterns(context);
1711 
    // Default control function: fuse only when the producer has a single use.
    ControlFusionFn defaultControlFn = [](const OpResult &producer,
                                          const OpOperand &consumer) {
      return producer.hasOneUse();
    };
1717 
1718     // Add elementwise op fusion patterns.
1719     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
1720 
    // Add folding with reshape by expansion patterns.
    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
1722 
1723     // Add the sparse tensor rewriting patterns.
1724     populateSparseTensorRewriting(patterns);
1725 
1726     // General canonicalization patterns.
1727     AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1728     GenericOp::getCanonicalizationPatterns(patterns, context);
1729     tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1730     tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1731     context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1732         patterns);
1733 
1734     // Add constant folding patterns.
1735     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
1736 
    // Use top-down traversal to reduce compile time.
1738     GreedyRewriteConfig grc;
1739     grc.useTopDownTraversal = true;
1740     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1741                                        grc);
1742   }
1743 };
1744 
1745 } // namespace
1746 
1747 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1748   return std::make_unique<LinalgElementwiseOpFusionPass>();
1749 }
1750