1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 //===---------------------------------------------------------------------===//
32 // Methods and patterns that fuse elementwise `linalg.generic` operations.
33 //===---------------------------------------------------------------------===//
34 
35 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
36 /// the `producer` to use in the fused operation given the indexing map of the
37 /// result of the producer in the consumer.
38 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
39     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
40     AffineMap fusedConsumerArgIndexMap) {
41   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
42   // from consumer loop -> consumer arg tensor index/producer result tensor
43   // index. The fused loop is same as the consumer loop. For each producer arg
44   // the indexing map to be computed is a map from consumer loop -> producer
45   // arg tensor index.
46   // producerResultIndexMap is a map from producer loop -> tensor index.
47   // Compute the inverse to get map from tensor index -> producer loop.
48   // The inverse is a map from producer result tensor index -> producer loop.
49   AffineMap invProducerResultIndexMap =
50       inversePermutation(producerResultIndexMap);
51   assert(invProducerResultIndexMap &&
52          "expected producer result indexing map to be invertible");
53 
54   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
55   // argMap is a map from producer loop -> producer arg tensor index.
56   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
57 
58   // Compose argMap with invProducerResultIndexMap to get a map from
59   // producer result tensor index -> producer arg tensor index.
60   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
61 
62   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
63   // consumer loop/ fused loop -> producer arg tensor index.
64   return t1.compose(fusedConsumerArgIndexMap);
65 }
66 
67 /// Conditions for elementwise fusion of generic operations.
68 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
69                                      OpOperand *consumerOpOperand) {
70   // Producer and consumer must have tensor semantics.
71   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
72     return false;
73 
74   // Verify that
75   // - the producer has all "parallel" iterator type.
76   if (producer.getNumParallelLoops() != producer.getNumLoops())
77     return false;
78 
79   // Only allow fusing the producer of an input operand for now.
80   // TODO: allow fusing the producer of an output operand.
81   if (!consumer.isInputTensor(consumerOpOperand))
82     return false;
83 
84   // Get the consumer index map. The number of results of the consumer index
85   // map must match the number of loops of the producer.
86   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
88     return false;
89 
90   // Currently support only operations with single result.
91   if (producer.getNumOutputs() != 1)
92     return false;
93 
94   // Finally the index_map for the result must be invertible. For now just
95   // verify it is a permutation.
96   AffineMap producerResultIndexMap =
97       producer.getTiedIndexingMap(producer.getOutputOperand(0));
98   if (!producerResultIndexMap.isPermutation())
99     return false;
100 
101   // Ensure that the fusion does not remove size information required to
102   // get the loop bounds. For non-reduction generics, this is trivially the
103   // case due to the output operand. For reductions, we need to check that after
104   // the fusion, each loop dimension has at least one input that defines it.
105   if ((consumer.getNumReductionLoops())) {
106     BitVector coveredDims(consumer.getNumLoops(), false);
107 
108     auto addToCoveredDims = [&](AffineMap map) {
109       for (auto result : map.getResults())
110         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
111           coveredDims[dimExpr.getPosition()] = true;
112     };
113 
114     for (auto pair :
115          llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
116       Value operand = std::get<0>(pair);
117       if (operand == consumerOpOperand->get())
118         continue;
119       AffineMap operandMap = std::get<1>(pair);
120       addToCoveredDims(operandMap);
121     }
122 
123     for (OpOperand *operand : producer.getInputOperands()) {
124       AffineMap newIndexingMap =
125           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
126               operand, producerResultIndexMap, consumerIndexMap);
127       addToCoveredDims(newIndexingMap);
128     }
129     if (!coveredDims.all())
130       return false;
131   }
132 
133   return true;
134 }
135 
136 /// Generate the region of the fused tensor operation. The region of the fused
137 /// op must be empty.
138 static void
139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
140                                  AffineMap consumerToProducerLoopsMap,
141                                  OpOperand *consumerOpOperand,
142                                  unsigned nloops) {
143   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
144   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
145   // Build the region of the fused op.
146   Block &producerBlock = producer->getRegion(0).front();
147   Block &consumerBlock = consumer->getRegion(0).front();
148   Block *fusedBlock = new Block();
149   fusedOp.region().push_back(fusedBlock);
150   BlockAndValueMapping mapper;
151   OpBuilder::InsertionGuard guard(rewriter);
152   rewriter.setInsertionPointToStart(fusedBlock);
153 
154   // 2. Add an index operation for every fused loop dimension and use the
155   // `consumerToProducerLoopsMap` to map the producer indices.
156   if (producer.hasIndexSemantics()) {
157     // Add an index operation for every fused loop dimension.
158     unsigned numFusedOpLoops =
159         std::max(producer.getNumLoops(), consumer.getNumLoops());
160     SmallVector<Value> fusedIndices;
161     fusedIndices.reserve(numFusedOpLoops);
162     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
164                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
165                     });
166     for (IndexOp indexOp :
167          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
169           producer.getLoc(),
170           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
171       mapper.map(indexOp.getResult(), newIndex);
172     }
173   }
174   // TODO: allow fusing the producer of an output operand.
175   assert(consumer.isInputTensor(consumerOpOperand) &&
176          "expected producer of input operand");
177   // 3. Consumer input operands up to consumerIdx (exclusive).
178   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
179            consumerOpOperand->getOperandNumber())) // input assumption.
180     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
181 
182   // Replacing consumerIdx requires getting the cloned, yielded, value from
183   // the (cloned) producer block. This happens in step 9.
184 
185   // 4. Splice in producer's input operands.
186   for (BlockArgument bbArg :
187        producerBlock.getArguments().take_front(producer.getNumInputs()))
188     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
189 
190   // 4.b. Producer output operand/map that is fused needs to be mapped to the
191   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
192   assert(producer->getNumResults() == 1 && "expected single result producer");
193   if (producer.isInitTensor(producer.getOutputOperand(0))) {
194     BlockArgument bbArg = producerBlock.getArguments()
195                               .drop_front(producer.getNumInputs())
196                               // TODO: bbArg index of
197                               .front();
198     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
199   }
200   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
201   for (BlockArgument bbArg :
202        consumerBlock.getArguments()
203            .take_front(consumer.getNumInputs())
204            .drop_front(consumerOpOperand->getOperandNumber() + 1))
205     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
206   // 6. All of consumer's output operands.
207   for (BlockArgument bbArg :
208        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
209     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
210   // 7. All of producer's output operands except the one fused.
211   // TODO: allow fusion of multi-result producers.
212   assert(producer->getNumResults() == 1 && "expected single result producer");
213 
214   // 8. Clone all producer operations except for the yield and index operations
215   // to the fused operation.
216   for (auto &op : producerBlock.without_terminator()) {
217     if (!isa<IndexOp>(op))
218       rewriter.clone(op, mapper);
219   }
220   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
221   // forward the yield operand.
222   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
223   // TODO: allow fusion of multi-result producers.
224   assert(producer->getNumResults() == 1 && "expected single result producer");
225   unsigned producerResultNumber = 0;
226   Value replacement =
227       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
228   // Sanity checks, if replacement is not already in the mapper then it must be
229   // produced outside.
230   if (replacement == yieldOp.getOperand(producerResultNumber)) {
231     if (auto bb = replacement.dyn_cast<BlockArgument>())
232       assert(bb.getOwner() != &producerBlock &&
233              "yielded block argument must have been mapped");
234     else
235       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
236              "yielded value must have been mapped");
237   }
238   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
239              replacement);
240   // 10. Clone operations from the consumer to the fused op.
241   for (auto &op : consumerBlock.getOperations())
242     rewriter.clone(op, mapper);
243 
244   // Sanity checks.
245   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
246          "Ill-formed GenericOp region");
247 }
248 
249 static Optional<SmallVector<Value>>
250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
251                        const ControlFusionFn &controlFn,
252                        PatternRewriter &rewriter) {
253   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
254   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
255       !controlFn(producer->getResult(0), *consumerOpOperand))
256     return llvm::None;
257 
258   // TODO: allow fusing the producer of an output operand.
259   assert(consumer.isInputTensor(consumerOpOperand) &&
260          "expected producer of input operand");
261 
262   // Compute the fused operands list and indexing maps.
263   SmallVector<Value> fusedOperands;
264   SmallVector<AffineMap> fusedIndexMaps;
265   fusedOperands.reserve(producer->getNumOperands() +
266                         consumer->getNumOperands());
267   fusedIndexMaps.reserve(producer->getNumOperands() +
268                          consumer->getNumOperands());
269   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
270   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
271   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
272   SmallVector<OpOperand *>::iterator it =
273       llvm::find(consumerInputs, consumerOpOperand);
274   assert(it != consumerInputs.end() && "expected to find the consumer operand");
275   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276     fusedOperands.push_back(opOperand->get());
277     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
278   }
279   // 4. Splice in producer's input operands/maps.
280   assert(producer->getNumResults() == 1 && "expected single result producer");
281   AffineMap producerResultIndexMap =
282       producer.getTiedIndexingMap(producer.getOutputOperand(0));
283   for (OpOperand *opOperand : producer.getInputOperands()) {
284     fusedOperands.push_back(opOperand->get());
285     // Compute indexing maps for the producer args in the fused operation.
286     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
287         opOperand, producerResultIndexMap,
288         consumer.getTiedIndexingMap(consumerOpOperand));
289     fusedIndexMaps.push_back(map);
290   }
291   // 4.b. Producer output operand/map that is fused needs to be passed if it is
292   // an "initTensor" (i.e. its value is actually read).
293   assert(producer->getNumResults() == 1 && "expected single result producer");
294   if (producer.isInitTensor(producer.getOutputOperand(0))) {
295     fusedOperands.push_back(producer.getOutputOperand(0)->get());
296     // Compute indexing maps for the producer args in the fused operation.
297     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
298         producer.getOutputOperand(0), producerResultIndexMap,
299         consumer.getTiedIndexingMap(consumerOpOperand));
300     fusedIndexMaps.push_back(map);
301   }
302   // 5. Remaining consumer's input operands/maps (drop past index
303   // `consumerIdx`).
304   for (OpOperand *opOperand :
305        llvm::make_range(std::next(it), consumerInputs.end())) {
306     fusedOperands.push_back(opOperand->get());
307     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
308   }
309   // 6. All of consumer's output operands (skip operands: added by the builder).
310   for (OpOperand *opOperand : consumer.getOutputOperands())
311     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
312   // 7. All of producer's output operands/maps except the one fused.
313   // TODO: allow fusion of multi-result producers.
314   assert(producer->getNumResults() == 1 && "expected single result producer");
315 
316   // Generate the fused op.
317   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
318   auto fusedOp = rewriter.create<GenericOp>(
319       consumer.getLoc(), consumer->getResultTypes(),
320       /*inputs=*/fusedOperands,
321       // TODO: handle outputs.
322       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
323       consumer.iterator_types(),
324       /*doc=*/nullptr,
325       /*library_call=*/nullptr);
326   if (!fusedOp.getShapesToLoopsMap()) {
327     // Fused op has invalid indexing maps. Typically this means something is off
328     // in the input, but going ahead here would result in verification errors.
329     // So cleanup and abort.
330     rewriter.eraseOp(fusedOp);
331     return llvm::None;
332   }
333 
334   // Construct an AffineMap from consumer loops to producer loops.
335   // consumer loop -> tensor index
336   AffineMap consumerResultIndexMap =
337       consumer.getTiedIndexingMap(consumerOpOperand);
338   // tensor index -> producer loop
339   AffineMap invProducerResultIndexMap =
340       inversePermutation(producerResultIndexMap);
341   assert(invProducerResultIndexMap &&
342          "expected producer result indexig map to be invertible");
343   // consumer loop -> producer loop
344   AffineMap consumerToProducerLoopsMap =
345       invProducerResultIndexMap.compose(consumerResultIndexMap);
346 
347   generateFusedElementwiseOpRegion(rewriter, fusedOp,
348                                    consumerToProducerLoopsMap,
349                                    consumerOpOperand, consumer.getNumLoops());
350   return SmallVector<Value>(fusedOp->getResults());
351 }
352 
353 static Optional<SmallVector<Value>>
354 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
355                    GenericOp producer, const ControlFusionFn &controlFn) {
356   if (producer->getNumResults() != 1)
357     return llvm::None;
358 
359   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360                                 rewriter);
361 }
362 
363 namespace {
364 /// Patterns to fuse a generic op, with the producer of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
367   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
368                      PatternBenefit benefit = 1)
369       : OpRewritePattern<GenericOp>(context, benefit),
370         controlFn(std::move(fun)) {}
371 
372   LogicalResult matchAndRewrite(GenericOp genericOp,
373                                 PatternRewriter &rewriter) const override {
374     // Find the first operand that is defined by another generic op on tensors.
375     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
376       auto producer =
377           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378       if (!producer || !producer.hasTensorSemantics())
379         continue;
380       Optional<SmallVector<Value>> fusedOpResults =
381           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
382       if (fusedOpResults) {
383         rewriter.replaceOp(genericOp, *fusedOpResults);
384         return success();
385       }
386     }
387     return failure();
388   }
389 
390 private:
391   ControlFusionFn controlFn;
392 };
393 } // namespace
394 
395 //===---------------------------------------------------------------------===//
396 // Methods and patterns that fuse reshape ops with elementwise operations by
397 // linearization of indexing maps.
398 //===---------------------------------------------------------------------===//
399 
400 // TODO(ravishankarm): The indexing maps
401 // these produce in the general case are detrimental to transformations.
402 // These patterns are on deprecation path in favor of using fusion by
403 // collapsing, which covers the only legitimate use case of this pattern of
404 // folding unit-extent dims.
405 
406 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
407 /// provided, given the shape of the source tensor that corresponds to the
408 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
409 /// are "row-major" ordered logically.
410 ///
411 /// For example:
412 ///
413 /// %0 = op ... : tensor<?x?x4x5xf32>
414 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
415 ///
416 /// and reshape:
417 /// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] :
418 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
419 ///
420 /// would be rewritten into:
421 /// %0 = op ... : tensor<?x?x4x5xf32>
422 /// with output index_map
423 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
424 template <typename TensorReshapeOp>
425 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
426                                         TensorReshapeOp reshapeOp) {
427   constexpr bool isExpanding =
428       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
429   ArrayRef<int64_t> sourceShape =
430       (isExpanding ? reshapeOp.getResultType().getShape()
431                    : reshapeOp.getSrcType().getShape());
432   SmallVector<AffineExpr> resultExprs;
433   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
434   MLIRContext *context = sourceMap.getContext();
435 
436   // Compute the result exprs based on the reassociation maps.
437   for (auto &indices : reshapeOp.getReassociationIndices()) {
438     // Assume that they are in-order and contiguous (already checked in
439     // verifier).
440     assert(!indices.empty());
441     SmallVector<int64_t> sizes;
442     SmallVector<AffineExpr> dimExprs;
443     for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
444                              sourceExprs.slice(indices[0], indices.size()))) {
445       if (std::get<0>(en) == 1)
446         continue;
447       sizes.push_back(std::get<0>(en));
448       dimExprs.push_back(std::get<1>(en));
449     }
450     AffineExpr linearizedExpr =
451         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
452     resultExprs.push_back(linearizedExpr);
453   }
454   // The new affine map cannot drop unused dimension but some new symbols may
455   // have been added. Create a map with at least as many dimensions/symbols as
456   // the original affine map.
457   int64_t maxDim = -1;
458   int64_t maxSym = -1;
459   getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
460   unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
461   unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
462   return AffineMap::get(numDims, numSyms, resultExprs, context);
463 }
464 
465 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
466 // producer). Fusing when operand has higher rank will require use of mods and
467 // divs in the indexing maps of the fused op which would make it non-invertible.
468 static bool isTensorReshapeOpFoldableByLinearization(
469     tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
470   if (!asProducer)
471     return false;
472   return useIndexMap.isPermutation();
473 }
474 
475 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
476 // consumer).
477 static bool
478 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
479                                          AffineMap useIndexMap,
480                                          bool asProducer) {
481   if (asProducer)
482     return false;
483   return useIndexMap.isPermutation();
484 }
485 
486 /// Check if the reshape operation is only expansion into/collapsing of
487 /// unit-dimension.
488 template <typename TensorReshapeOp>
489 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
490   constexpr bool isExpanding =
491       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
492   ArrayRef<int64_t> expandedShape =
493       (isExpanding ? reshapeOp.getResultType().getShape()
494                    : reshapeOp.getSrcType().getShape());
495   for (auto &indices : reshapeOp.getReassociationIndices()) {
496     unsigned numUnitDims = 0;
497     for (int64_t position : indices)
498       if (expandedShape[position] == 1)
499         numUnitDims++;
500     if (numUnitDims != indices.size() - 1)
501       return false;
502   }
503   return true;
504 }
505 
506 namespace {
507 /// Pattern to fold tensor_expand_shape op with its consumer by using the source
508 /// of the reshape op as the operand in the consumer (instead of the result of
509 /// the tensor_collapse_shape). The corresponding index map in the consumer
510 /// needs to be modified to linearize the folded dimension.
511 ///
512 /// For example,
513 ///
514 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
515 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
516 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
517 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
518 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
519 ///        -> tensor<?x?x4x?xf32>
520 ///
521 /// can be folded into
522 ///
523 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
524 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
525 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
526 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
527 ///        -> tensor<?x?x4x?xf32>
528 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
529 struct FoldProducerReshapeOpByLinearization
530     : public OpRewritePattern<GenericOp> {
531   using OpRewritePattern<GenericOp>::OpRewritePattern;
532 
533   LogicalResult matchAndRewrite(GenericOp genericOp,
534                                 PatternRewriter &rewriter) const override {
535     if (!genericOp.hasTensorSemantics())
536       return failure();
537     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
538     for (const auto &en : llvm::enumerate(inputOperands)) {
539       auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
540       if (!reshapeOp)
541         continue;
542 
543       if (!isTensorReshapeOpFoldableByLinearization(
544               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
545               /*asProducer =*/true) ||
546           (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
547         continue;
548 
549       // Compute the fused operands list,
550       SmallVector<Value> fusedOperands = genericOp.getInputOperands();
551       fusedOperands[en.index()] = reshapeOp.src();
552       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
553       llvm::append_range(fusedOperands, outputOperands);
554 
555       // Compute indexing_maps for the fused operation. The indexing_maps for
556       // the operands of the consumers that arent fused are the same.
557       SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
558 
559       // Compute the indexing map to use for the result of the producer.
560       AffineMap modifiedMap =
561           linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
562       // The modified map cannot have symbols.
563       if (modifiedMap.getNumSymbols())
564         return failure();
565       for (AffineExpr expr : modifiedMap.getResults()) {
566         if (!expr.isPureAffine())
567           return failure();
568       }
569       fusedIndexMaps[en.index()] = modifiedMap;
570 
571       // Further check that the resulting index maps can be fused and
572       // inverted. Without this the resultant op is not legal.
573       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
574         return rewriter.notifyMatchFailure(
575             genericOp, "fused op loop bound computation failed");
576       }
577 
578       rewriter.startRootUpdate(genericOp);
579       genericOp->setOperands(fusedOperands);
580       genericOp.indexing_mapsAttr(
581           rewriter.getAffineMapArrayAttr(fusedIndexMaps));
582       rewriter.finalizeRootUpdate(genericOp);
583       return success();
584     }
585     return failure();
586   }
587 };
588 
589 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
590 /// producer. The corresponding index map in the consumer needs to be modified
591 /// to linearize the folded dimension.
592 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
593 struct FoldConsumerReshapeOpByLinearization
594     : public OpRewritePattern<TensorReshapeOp> {
595   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
596 
597   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
598                                 PatternRewriter &rewriter) const override {
599     GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
600     if (!producer || !producer.hasTensorSemantics() ||
601         producer.getNumOutputs() != 1 ||
602         !isTensorReshapeOpFoldableByLinearization(
603             reshapeOp,
604             producer.getTiedIndexingMap(producer.getOutputOperand(0)),
605             /*asProducer =*/false) ||
606         (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
607       return failure();
608     // The indexing_maps for the operands of the fused operation are same as
609     // those for the operands of the producer.
610     SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
611 
612     // Compute the indexing map to use for the operand of the producer.
613     AffineMap modifiedMap = linearizeCollapsedDims(
614         producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
615     for (AffineExpr expr : modifiedMap.getResults()) {
616       if (!expr.isPureAffine()) {
617         return rewriter.notifyMatchFailure(
618             producer, "fused op indexing map is not affine");
619       }
620     }
621     fusedIndexMaps.back() = modifiedMap;
622 
623     // Further check that the resulting index maps can be fused and
624     // inverted. Without this the resultant op is not legal.
625     if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
626       return rewriter.notifyMatchFailure(
627           producer, "fused op loop bound computation failed");
628     }
629 
630     Location loc = producer.getLoc();
631     SmallVector<Value> inputOperands = producer.getInputOperands();
632     Value output = rewriter.create<TensorReshapeOp>(
633         loc, producer.getOutputOperand(0)->get(),
634         reshapeOp.getReassociationExprs());
635     auto fusedOp = rewriter.create<GenericOp>(
636         loc, reshapeOp.getResultType(),
637         /*inputs=*/inputOperands,
638         // TODO: handle outputs.
639         /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
640         producer.iterator_types(),
641         /*doc=*/nullptr,
642         /*library_call=*/nullptr);
643     auto &fusedRegion = fusedOp->getRegion(0);
644     rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
645                                fusedRegion.begin());
646     rewriter.replaceOp(reshapeOp, fusedOp->getResults());
647     return success();
648   }
649 };
650 } // namespace
651 
652 //===---------------------------------------------------------------------===//
653 // Methods and patterns that fuse reshape ops with elementwise operations by
654 // expanding the dimensionality of the elementwise operations.
655 //===---------------------------------------------------------------------===//
656 
657 /// Conditions for folding a generic operation with a reshape op by expanding
658 /// the iteration space dimensionality for tensor operations. These are
659 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
660 /// following fusion pattern.
661 ///
662 ///  Consider
663 ///
664 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
665 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
666 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
667 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
668 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
669 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
670 ///
671 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
672 ///  is increased to match the result (operand) of the tensor_expand_shape.
673 ///  The indexing_map of the fused tensor in the `genericOp` and the
674 ///  reassociation map helps compute the indexing maps of the modified op.
675 ///  For the above example, based on the reassociation map it
676 ///  can be concluded that
677 ///
678 ///  - The loop used to access the first dimension of the fused tensor is split
679 ///    into two.
680 ///  - The loop used to access the second dimension of the fused tensor is kept
681 ///    as is.
682 ///  - The loop used to access the third dimension of the fused tensor is split
683 ///    into three.
684 ///
685 ///  i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
686 ///  op, then
687 ///
688 ///   d0 -> e0, e1
689 ///   d1 -> e2, e3, e4
690 ///   d2 -> e5
691 ///
692 ///  substituting this, the generic op can be rewritten as
693 ///
694 ///  %d = linalg.generic ins(%0, %1 : )
695 ///        indexing_maps =
696 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
697 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
698 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
699 ///
700 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
701 ///  to make it consistent
702 ///
703 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
704 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
705 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
706 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
707 ///
708 ///  The added reshapes are again expanding patterns, so they will get fused
709 ///  with its producers if possible.
710 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
711                                                OpOperand *fusableOpOperand) {
712   // Is fusable only if:
713   // - All the indexing maps for operands and results are projected
714   //   permutations.
715   // - The fused tensor is not a scalar.
716   // - All the loops are parallel loops.
717   return genericOp.hasTensorSemantics() &&
718          llvm::all_of(genericOp.indexing_maps().getValue(),
719                       [](Attribute attr) {
720                         return attr.cast<AffineMapAttr>()
721                             .getValue()
722                             .isProjectedPermutation();
723                       }) &&
724          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
725          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
726            return attr.cast<StringAttr>().getValue() ==
727                   getParallelIteratorTypeName();
728          });
729 }
730 
731 namespace {
732 /// Information needed to expand a generic operation to fold the reshape with
733 /// it.
734 class ExpansionInfo {
735 public:
736   // Computes the mapping from original dimensions of the op to the dimensions
737   // of the expanded op given the `indexingMap` of the fused operand/result of
738   // the generic op, the `reassocationMaps` of the reshape op and the shape of
739   // the expanded op.
740   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
741                         ArrayRef<AffineMap> reassociationMaps,
742                         ArrayRef<int64_t> expandedShape,
743                         ArrayRef<int64_t> collapsedShape,
744                         PatternRewriter &rewriter);
745   unsigned getOrigOpNumDims() const { return reassociation.size(); }
746   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
747   ReassociationIndicesRef getExpandedDims(unsigned i) const {
748     return reassociation[i];
749   }
750   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
751     return expandedShapeMap[i];
752   }
753   ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
754 
755 private:
756   /// Reassociation from the dimensions in the original operation to the
757   /// dimension of the expanded operation.
758   SmallVector<ReassociationIndices> reassociation;
759   /// Mapping from extent of loops in the original operation, to the extent of
760   /// loops in the expanded operation.
761   SmallVector<SmallVector<int64_t>> expandedShapeMap;
762   /// Extent of the loop in the original operation.
763   SmallVector<int64_t> originalLoopExtent;
764   unsigned expandedOpNumDims;
765 };
766 } // namespace
767 
768 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
769                                      OpOperand *fusableOpOperand,
770                                      ArrayRef<AffineMap> reassociationMaps,
771                                      ArrayRef<int64_t> expandedShape,
772                                      ArrayRef<int64_t> collapsedShape,
773                                      PatternRewriter &rewriter) {
774   if (reassociationMaps.empty())
775     return failure();
776   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
777 
778   Optional<SmallVector<int64_t, 4>> originalLoopRange =
779       linalgOp.getStaticLoopRanges();
780   if (!originalLoopRange)
781     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
782   originalLoopExtent.assign(originalLoopRange->begin(),
783                             originalLoopRange->end());
784 
785   reassociation.clear();
786   expandedShapeMap.clear();
787   // Compute the number of dimension in the expanded op that correspond to each
788   // dimension of the original op.
789   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
790   expandedShapeMap.resize(fusedIndexMap.getNumDims());
791   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
792     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
793     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
794     numExpandedDims[pos] = foldedDims.getNumResults();
795     ArrayRef<int64_t> shape =
796         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
797     expandedShapeMap[pos].assign(shape.begin(), shape.end());
798   }
799   // The remaining dimensions remain the same.
800   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
801     if (expandedShapeMap[i].empty())
802       expandedShapeMap[i] = {originalLoopExtent[i]};
803 
804   // Compute reassociation map from the original op to the expanded op.
805   unsigned sum = 0;
806   reassociation.reserve(fusedIndexMap.getNumDims());
807   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
808     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
809     reassociation.emplace_back(seq.begin(), seq.end());
810     sum += numFoldedDim.value();
811   }
812   expandedOpNumDims = sum;
813   return success();
814 }
815 
816 /// Epanding the body of a linalg operation requires adaptations of the accessed
817 /// loop indices. Specifically, access of indices in the original operation need
818 /// to be replaced with linearizations of indices in the expanded op. That
819 /// requires the shape of the expanded dimensions to be static (at least all but
820 /// the most significant). For now check that these are all statically sized.
821 /// Note that this could be extended to handle dynamic case, but the
822 /// implementation below uses `affine.apply` which seems to have issues when the
823 /// shapes are not static.
824 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
825                                            const ExpansionInfo &expansionInfo,
826                                            PatternRewriter &rewriter) {
827   if (!genericOp.hasIndexSemantics())
828     return success();
829   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
830     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
831     if (expandedShape.size() == 1)
832       continue;
833     for (int64_t shape : expandedShape.drop_front()) {
834       if (ShapedType::isDynamic(shape)) {
835         return rewriter.notifyMatchFailure(
836             genericOp, "cannot expand due to index semantics and dynamic dims");
837       }
838     }
839   }
840   return success();
841 }
842 
843 /// Return the indexing map to use in the expanded op for a given the
844 /// `indexingMap` of the original operation.
845 static AffineMap
846 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
847                            const ExpansionInfo &expansionInfo) {
848   SmallVector<AffineExpr> newExprs;
849   for (AffineExpr expr : indexingMap.getResults()) {
850     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
851     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
852         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
853           return builder.getAffineDimExpr(static_cast<unsigned>(v));
854         }));
855     newExprs.append(expandedExprs.begin(), expandedExprs.end());
856   }
857   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
858                         indexingMap.getNumSymbols(), newExprs,
859                         builder.getContext());
860 }
861 
862 /// Return the type of the operand/result to use in the expanded op given the
863 /// type in the original op.
864 static RankedTensorType getExpandedType(RankedTensorType originalType,
865                                         AffineMap indexingMap,
866                                         const ExpansionInfo &expansionInfo) {
867   SmallVector<int64_t> expandedShape;
868   for (AffineExpr expr : indexingMap.getResults()) {
869     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
870     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
871     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
872   }
873   return RankedTensorType::get(expandedShape, originalType.getElementType());
874 }
875 
876 /// Returns the reassociation maps to use in the `tensor.expand_shape`
877 /// operation to convert the operands of the original operation to operands of
878 /// the expanded operation. The same method is used to compute the
879 /// `tensor.collapse_shape` used to collapse the result of the expanded
880 /// op to get the value that can replace all uses of the results of the original
881 /// op.
882 static SmallVector<ReassociationIndices>
883 getReassociationForExpansion(AffineMap indexingMap,
884                              const ExpansionInfo &expansionInfo) {
885   SmallVector<ReassociationIndices> reassociation;
886   unsigned numReshapeDims = 0;
887   for (AffineExpr expr : indexingMap.getResults()) {
888     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
889     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
890     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
891         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
892     reassociation.emplace_back(std::move(indices));
893     numReshapeDims += numExpandedDims;
894   }
895   return reassociation;
896 }
897 
898 /// Update the body of an expanded linalg operation having index semantics. The
899 /// indices of the original operation need to be recovered by linearizing the
900 /// indices of the correspoding dimensions of the expanded operation. For now it
901 /// is assumed that the shapes of the expanded operation needed for
902 /// linearization are static.
903 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
904                                           Location loc, Region &fusedRegion,
905                                           const ExpansionInfo &expansionInfo) {
906   // Replace the original indices by the linearization of the expanded indices.
907   for (IndexOp indexOp :
908        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
909     ArrayRef<int64_t> expandedDims =
910         expansionInfo.getExpandedDims(indexOp.dim());
911     assert(!expandedDims.empty() && "expected valid expansion info");
912 
913     // Skip index operations that are not affected by the expansion.
914     if (expandedDims.size() == 1 &&
915         expandedDims.front() == (int64_t)indexOp.dim())
916       continue;
917 
918     // Linearize the expanded indices of the original index dimension.
919     OpBuilder::InsertionGuard guard(rewriter);
920     rewriter.setInsertionPointAfter(indexOp);
921     ArrayRef<int64_t> expandedDimsShape =
922         expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
923     SmallVector<Value> expandedIndices;
924     expandedIndices.reserve(expandedDims.size() - 1);
925     llvm::transform(
926         expandedDims.drop_front(), std::back_inserter(expandedIndices),
927         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
928     Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
929     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
930       assert(!ShapedType::isDynamic(std::get<0>(it)));
931       AffineExpr idx, acc;
932       bindDims(rewriter.getContext(), idx, acc);
933       newIndex = rewriter.create<AffineApplyOp>(
934           indexOp.getLoc(), idx + acc * std::get<0>(it),
935           ValueRange{std::get<1>(it), newIndex});
936     }
937     rewriter.replaceOp(indexOp, newIndex);
938   }
939 }
940 
941 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
942 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
943 /// that those conditions have been satisfied.
944 static Optional<SmallVector<Value>>
945 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
946                            OpOperand *fusableOpOperand,
947                            PatternRewriter &rewriter) {
948   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
949          "preconditions for fuse operation failed");
950   // Check if reshape is expanding or collapsing.
951   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
952   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
953   bool isExpanding = (expandingReshapeOp != nullptr);
954   RankedTensorType expandedType = isExpanding
955                                       ? expandingReshapeOp.getResultType()
956                                       : collapsingReshapeOp.getSrcType();
957   RankedTensorType collapsedType = isExpanding
958                                        ? expandingReshapeOp.getSrcType()
959                                        : collapsingReshapeOp.getResultType();
960 
961   ExpansionInfo expansionInfo;
962   if (failed(expansionInfo.compute(
963           genericOp, fusableOpOperand,
964           isExpanding ? expandingReshapeOp.getReassociationMaps()
965                       : collapsingReshapeOp.getReassociationMaps(),
966           expandedType.getShape(), collapsedType.getShape(), rewriter)))
967     return llvm::None;
968 
969   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
970     return llvm::None;
971 
972   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
973       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
974         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
975       }));
976 
977   SmallVector<Value> expandedOpOperands;
978   expandedOpOperands.reserve(genericOp.getNumInputs());
979   for (OpOperand *opOperand : genericOp.getInputOperands()) {
980     if (opOperand == fusableOpOperand) {
981       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
982                                                : collapsingReshapeOp.src());
983       continue;
984     }
985     if (genericOp.isInputTensor(opOperand)) {
986       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
987       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
988       RankedTensorType expandedOperandType =
989           getExpandedType(opOperandType, indexingMap, expansionInfo);
990       if (expandedOperandType != opOperand->get().getType()) {
991         // Reshape the operand to get the right type.
992         SmallVector<ReassociationIndices> reassociation =
993             getReassociationForExpansion(indexingMap, expansionInfo);
994         if (failed(reshapeLikeShapesAreCompatible(
995                 [&](const Twine &msg) {
996                   return rewriter.notifyMatchFailure(genericOp, msg);
997                 },
998                 opOperandType.getShape(), expandedOperandType.getShape(),
999                 reassociation,
1000                 /*isExpandingReshape=*/true)))
1001           return llvm::None;
1002         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
1003             genericOp.getLoc(), expandedOperandType, opOperand->get(),
1004             reassociation));
1005         continue;
1006       }
1007     }
1008     expandedOpOperands.push_back(opOperand->get());
1009   }
1010 
1011   Location loc = genericOp.getLoc();
1012   SmallVector<Value> outputs;
1013   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
1014     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1015     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
1016     RankedTensorType expandedOutputType =
1017         getExpandedType(opOperandType, indexingMap, expansionInfo);
1018     if (expandedOutputType != opOperand->get().getType()) {
1019       SmallVector<ReassociationIndices> reassociation =
1020           getReassociationForExpansion(indexingMap, expansionInfo);
1021       if (failed(reshapeLikeShapesAreCompatible(
1022               [&](const Twine &msg) {
1023                 return rewriter.notifyMatchFailure(genericOp, msg);
1024               },
1025               opOperandType.getShape(), expandedOutputType.getShape(),
1026               reassociation,
1027               /*isExpandingReshape=*/true)))
1028         return llvm::None;
1029       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
1030           genericOp.getLoc(), expandedOutputType, opOperand->get(),
1031           reassociation));
1032     }
1033   }
1034 
1035   // The iterator types of the expanded op are all parallel.
1036   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
1037                                        getParallelIteratorTypeName());
1038 
1039   TypeRange resultTypes = ValueRange(outputs).getTypes();
1040   auto fusedOp =
1041       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
1042                                  /*inputs=*/expandedOpOperands, outputs,
1043                                  expandedOpIndexingMaps, iteratorTypes);
1044   Region &fusedRegion = fusedOp->getRegion(0);
1045   Region &originalRegion = genericOp->getRegion(0);
1046   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
1047 
1048   // Update the index accesses after the expansion.
1049   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
1050 
1051   // Reshape the result values to their original shape if this is a collapsing
1052   // reshape folded into its consumer.
1053   SmallVector<Value> resultVals;
1054   for (OpResult opResult : genericOp->getOpResults()) {
1055     int64_t resultNumber = opResult.getResultNumber();
1056     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
1057       SmallVector<ReassociationIndices> reassociation =
1058           getReassociationForExpansion(
1059               genericOp.getTiedIndexingMap(
1060                   genericOp.getOutputOperand(resultNumber)),
1061               expansionInfo);
1062       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
1063           genericOp.getLoc(), opResult.getType(),
1064           fusedOp->getResult(resultNumber), reassociation));
1065     } else {
1066       resultVals.push_back(fusedOp->getResult(resultNumber));
1067     }
1068   }
1069   // Assuming a single result.
1070   return resultVals;
1071 }
1072 
1073 namespace {
1074 
1075 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
1076 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
1077 /// in the consumer is expanded.
1078 class FoldWithProducerReshapeOpByExpansion
1079     : public OpRewritePattern<GenericOp> {
1080 public:
1081   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
1082                                        ControlFusionFn foldReshapes,
1083                                        PatternBenefit benefit = 1)
1084       : OpRewritePattern<GenericOp>(context, benefit),
1085         controlFoldingReshapes(std::move(foldReshapes)) {}
1086 
1087   LogicalResult matchAndRewrite(GenericOp genericOp,
1088                                 PatternRewriter &rewriter) const override {
1089     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1090       tensor::CollapseShapeOp reshapeOp =
1091           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1092       if (!reshapeOp)
1093         continue;
1094       // Fold only if
1095       // - The tensor reshape op is folding.
1096       // - All constraints of fusing with reshape by expansion are met.
1097       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1098           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1099         continue;
1100 
1101       Optional<SmallVector<Value>> replacementValues =
1102           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1103       if (!replacementValues)
1104         return failure();
1105       rewriter.replaceOp(genericOp, replacementValues.getValue());
1106       return success();
1107     }
1108     return failure();
1109   }
1110 
1111 private:
1112   ControlFusionFn controlFoldingReshapes;
1113 };
1114 
1115 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1116 /// by expanding the dimensionality of the loop in the producer op.
1117 struct FoldReshapeWithGenericOpByExpansion
1118     : public OpRewritePattern<tensor::ExpandShapeOp> {
1119 
1120   FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
1121                                       ControlFusionFn foldReshapes,
1122                                       PatternBenefit benefit = 1)
1123       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1124         controlFoldingReshapes(std::move(foldReshapes)) {}
1125 
1126   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1127                                 PatternRewriter &rewriter) const override {
1128     // Fold only if all constraints of fusing with reshape by expansion are met.
1129     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1130     if (!producer || producer.getNumOutputs() != 1 ||
1131         !isFusableWithReshapeByDimExpansion(producer,
1132                                             producer.getOutputOperand(0)) ||
1133         !controlFoldingReshapes(producer->getResult(0),
1134                                 reshapeOp->getOpOperand(0)))
1135       return failure();
1136     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1137         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1138     if (!replacementValues)
1139       return failure();
1140     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1141     return success();
1142   }
1143 
1144 private:
1145   ControlFusionFn controlFoldingReshapes;
1146 };
1147 } // namespace
1148 
1149 //===---------------------------------------------------------------------===//
1150 // Methods and patterns to fuse reshape with linalg.generic operations by
1151 // contraction of dimensions.
1152 //===---------------------------------------------------------------------===//
1153 
1154 /// For a given list of indices in the range of the `indexingMap` that are
1155 /// folded, return the indices of the corresponding domain. Return `llvm::None`
1156 /// on failure. Ensures that all the elements of the returned reassociation are
1157 /// distinct.
1158 static ReassociationIndices
1159 getDomainReassociation(AffineMap indexingMap,
1160                        ReassociationIndicesRef rangeReassociation) {
1161   assert(indexingMap.isProjectedPermutation() &&
1162          "expected projected permutation");
1163 
1164   ReassociationIndices domainReassociation = llvm::to_vector<4>(
1165       llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1166         return indexingMap.getResults()[pos]
1167             .cast<AffineDimExpr>()
1168             .getPosition();
1169       }));
1170   // The projected permutation semantics ensures that there is no repetition of
1171   // the domain indices.
1172   return domainReassociation;
1173 }
1174 
1175 /// For a given `dimSequence`, check if the sequence is conserved in the
1176 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
1177 /// Non-existence of the sequence returns true as well.
1178 static bool isDimSequencePreserved(AffineMap indexingMap,
1179                                    ReassociationIndicesRef dimSequence) {
1180   assert(!dimSequence.empty() &&
1181          "expected non-empty list for dimension sequence");
1182   assert(indexingMap.isProjectedPermutation() &&
1183          "expected indexing map to be projected permutation");
1184 
1185   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1186   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1187 
1188   unsigned dimSequenceStart = dimSequence[0];
1189   for (const auto &expr : enumerate(indexingMap.getResults())) {
1190     unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
1191     // 1.  Check if this start of the sequence.
1192     if (dimInMapStart == dimSequenceStart) {
1193       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
1194         return false;
1195       // 1a. Check if sequence is preserved.
1196       for (const auto &dimInSequence : enumerate(dimSequence)) {
1197         unsigned dimInMap =
1198             indexingMap.getResult(expr.index() + dimInSequence.index())
1199                 .cast<AffineDimExpr>()
1200                 .getPosition();
1201         if (dimInMap != dimInSequence.value())
1202           return false;
1203       }
1204       // Found the sequence. Projected permutation
1205       // enforces that all AffineDimExprs in the result are unique, so no
1206       // further checks are needed.
1207       return true;
1208     }
1209     // 2. If position in the expr (which is of type AffineDimExpr) is part
1210     // of sequence, return false here. This implies the entire sequence does not
1211     // exist in the indexing map.
1212     if (sequenceElements.count(dimInMapStart))
1213       return false;
1214   }
1215   // 3. No element of sequence found. Return true.
1216   return true;
1217 }
1218 
1219 // Return the list of dimensions of the iteration domain that can be
1220 // collapsed to allow for fusion with the a producer that is an expand_shape
1221 // operation. If all dimensions created by expansion can be collapsed in the
1222 // iteration space then the reshape is defunct.
1223 //
1224 // Example:
1225 //
1226 // ```mlir
1227 // #map = affine_map<(d0, d1) -> (d0, d1)>
1228 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1229 // %2 = linalg.init_tensor [..] : tensor<?x4xf32>
1230 // %3 = linalg.generic {
1231 //     indexing_maps = [#map, #map],
1232 //     iterator_types = ["parallel" ,"parallel"]}
1233 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1234 // ```
1235 //
1236 // can be fused by collapsing the dimensions of the iteration space.
1237 //
1238 // ```mlir
1239 // #map = affine_map<(d0) -> (d0)>
1240 // %2 = linalg.init_tensor [..] : tensor<?xf32>
1241 // %3 = linalg.generic {
1242 //     indexing_maps = [#map, #map],
1243 //     iterator_types = ["parallel"]}
1244 //     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1245 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1246 // ```
1247 //
1248 // In the following example,
1249 //
1250 // ```mlir
1251 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
1252 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
1253 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1254 // %2 = linalg.init_tensor [..] : tensor<4x?xf32>
1255 // %2 = linalg.generic {
1256 //     indexing_maps = [#map0, #map1],
1257 //     iterator_types = ["parallel" ,"parallel"]}
1258 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1259 // ```
1260 //
1261 // the reshape cannot be fused with the generic op by collapsing the op
1262 // dimensions since the indexing maps will have to contain mods and divs
1263 // to preserve the accesses pattern. When no dimensions of the iteration
1264 // space are collapsable and empty vector is returned.
1265 static SmallVector<ReassociationIndices>
1266 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1267                                  ArrayRef<ReassociationIndices> reassociation) {
1268   // Some basic checks for this fusion to be valid.
1269   if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1270     return {};
1271 
1272   if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
1273         return map.isProjectedPermutation();
1274       })) {
1275     return {};
1276   }
1277 
1278   // Compute all the loops with the reduction iterator types.
1279   SmallVector<int64_t> reductionDims;
1280   for (auto iteratorType : llvm::enumerate(genericOp.iterator_types())) {
1281     if (isReductionIterator(iteratorType.value())) {
1282       reductionDims.push_back(iteratorType.index());
1283     }
1284   }
1285 
1286   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1287   AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
1288   auto iteratorTypes = genericOp.iterator_types().getValue();
1289   SmallVector<ReassociationIndices> iterationSpaceReassociation;
1290   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1291     assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1292 
1293     // Ignore dims that are not folded.
1294     if (foldedRangeDims.size() == 1)
1295       continue;
1296 
1297     ReassociationIndices foldedIterationSpaceDims =
1298         getDomainReassociation(indexingMap, foldedRangeDims);
1299 
1300     // Check that the folded iteration dims do not contain already processed
1301     // dims.
1302     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1303           return processedIterationDims.count(dim);
1304         }))
1305       continue;
1306 
1307     // Check that all folded iterator types are all parallel or all reductions.
1308     Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
1309     if (!isParallelIterator(startIteratorType) &&
1310         !isReductionIterator(startIteratorType))
1311       continue;
1312     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1313           return iteratorTypes[dim] != startIteratorType;
1314         }))
1315       continue;
1316 
1317     // If the folded dimensions correspond to a "reduction" iterator type,
1318     // the folded dimensions need to be "in-order". Strictly speaking this is
1319     // not necessary, for reductions that are associative and commutative,  but
1320     // using a more strict definition of reduction for now.
1321     if (isReductionIterator(startIteratorType)) {
1322       bool isContiguous = false;
1323       for (auto startDim : llvm::enumerate(reductionDims)) {
1324         // Move window in `reductionDims` to start of the folded iteration dims.
1325         if (startDim.value() != foldedIterationSpaceDims[0])
1326           continue;
1327         // If sizes doesnt match, trivial not contiguous. This condition should
1328         // not be hit.
1329         if (startDim.index() + foldedIterationSpaceDims.size() >
1330             reductionDims.size())
1331           break;
1332         // Check that the contiguity is maintained.
1333         isContiguous = true;
1334         for (auto foldedDim : llvm::enumerate(foldedIterationSpaceDims)) {
1335           if (reductionDims[foldedDim.index() + startDim.index()] !=
1336               foldedDim.value()) {
1337             isContiguous = false;
1338             break;
1339           }
1340         }
1341         break;
1342       }
1343       if (!isContiguous)
1344         continue;
1345     }
1346 
1347     // Check that the sequence is preserved in all indexing maps.
1348     if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
1349           return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims);
1350         }))
1351       continue;
1352 
1353     processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1354                                   foldedIterationSpaceDims.end());
1355     iterationSpaceReassociation.emplace_back(
1356         std::move(foldedIterationSpaceDims));
1357   }
1358 
1359   return iterationSpaceReassociation;
1360 }
1361 
1362 /// Helper class to carry state while collapsing the `linalg.generic` op.
1363 namespace {
1364 class CollapsingInfo {
1365 public:
1366   LogicalResult initialize(unsigned origNumLoops,
1367                            ArrayRef<ReassociationIndices> foldedIterationDims) {
1368     llvm::SmallDenseSet<int64_t, 4> processedDims;
1369     // Find all the dims that are folded.
1370     for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1371       if (foldedIterationDim.empty())
1372         continue;
1373       // If the folded dims contain dims already folded, that's illegal
1374       // specification. Repetition within a list is also illegal.
1375       for (auto dim : foldedIterationDim) {
1376         if (dim >= origNumLoops)
1377           return failure();
1378         if (processedDims.count(dim))
1379           return failure();
1380         processedDims.insert(dim);
1381       }
1382       collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1383                                                    foldedIterationDim.end());
1384     }
1385     if (processedDims.size() > origNumLoops)
1386       return failure();
1387 
1388     // Add all the preserved dims of the original op as single
1389     // elements to `collapsedOpToOrigOpIterationDim`.
1390     for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1391       if (processedDims.count(dim))
1392         continue;
1393       collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1394     }
1395 
1396     llvm::sort(collapsedOpToOrigOpIterationDim,
1397                [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1398                  return lhs[0] < rhs[0];
1399                });
1400     origOpToCollapsedOpIterationDim.resize(origNumLoops);
1401     for (auto foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1402       for (auto dim : enumerate(foldedDims.value()))
1403         origOpToCollapsedOpIterationDim[dim.value()] =
1404             std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1405     }
1406     return success();
1407   }
1408 
1409   /// Return mapping from collapsed loop domain to original loop domain.
1410   ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1411     return collapsedOpToOrigOpIterationDim;
1412   }
1413 
1414   /// Return mapping from original loop domain to collapsed loop domain. The
1415   /// mapping is a pair. First value is the dimension in the collapsed loop that
1416   /// the original loop is mapped to. Second is the relative position in folded
1417   /// list of this domain. For example if the original loop domain is 3D, and
1418   /// the collapsed loop domain is folding all of it, i.e.
1419   ///
1420   /// ```
1421   /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
1422   /// ```
1423   ///
1424   /// then
1425   ///
1426   /// ```
1427   ///  origOpToCollapsedOpMapping[0] = {0, 0};
1428   ///  origOpToCollapsedOpMapping[1] = {0, 1};
1429   ///  origOpToCollapsedOpMapping[2] = {0, 2};
1430   ///  origOpToCollapsedOpMapping[3] = {1, 0};
1431   ///  origOpToCollapsedOpMapping[4] = {1, 1};
1432   /// ```
1433   ///
1434   ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1435     return origOpToCollapsedOpIterationDim;
1436   }
1437 
1438   /// Return the collapsed op iteration domain rank.
1439   unsigned getCollapsedOpIterationRank() const {
1440     return collapsedOpToOrigOpIterationDim.size();
1441   }
1442 
1443 private:
1444   /// Map from the iteration domain index in collapsed op to the iteration
1445   /// domain indices in the original op.
1446   SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1447 
1448   /// Map from iteration domain index in the original op to the iteration domain
1449   /// index in the collapsed op.
1450   SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1451 };
1452 } // namespace
1453 
1454 /// Get the iterator types for the collapsed operation given the original
1455 /// iterator types and collapsed dimensions.
1456 static SmallVector<StringRef>
1457 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
1458                             const CollapsingInfo &collapsingInfo) {
1459   SmallVector<StringRef> collapsedIteratorTypes;
1460   for (ReassociationIndicesRef foldedIterDims :
1461        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1462     assert(!foldedIterDims.empty() &&
1463            "reassociation indices expected to have non-empty sets");
1464     // Just pick the iterator type of the first folded dim. Pre-condition checks
1465     // expected to have checked that iterator types of all folded dimensions are
1466     // the same.
1467     collapsedIteratorTypes.push_back(
1468         iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1469   }
1470   return collapsedIteratorTypes;
1471 }
1472 
1473 /// Compute the indexing map in the collapsed op that corresponds to the given
1474 /// `indexingMap` of the original operation.
1475 static AffineMap
1476 getCollapsedOpIndexingMap(AffineMap indexingMap,
1477                           const CollapsingInfo &collapsingInfo) {
1478   MLIRContext *context = indexingMap.getContext();
1479   assert(indexingMap.isProjectedPermutation() &&
1480          "expected indexing map to be projected permutation");
1481   SmallVector<AffineExpr> resultExprs;
1482   auto origOpToCollapsedOpMapping =
1483       collapsingInfo.getOrigOpToCollapsedOpMapping();
1484   for (auto expr : indexingMap.getResults()) {
1485     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1486     // If the dim is not the first of the collapsed dim, do nothing.
1487     if (origOpToCollapsedOpMapping[dim].second != 0)
1488       continue;
1489     // The next n-dims are guaranteed to be collapsed. So just use the
1490     // iteration dimension of the collapsed op.
1491     resultExprs.push_back(
1492         getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1493   }
1494   return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1495                         resultExprs, context);
1496 }
1497 
1498 /// Return the `reassociation` indices to use to collapse the operand when the
1499 /// iteration space of a generic op is collapsed.
1500 static SmallVector<ReassociationIndices>
1501 getOperandReassociation(AffineMap indexingMap,
1502                         const CollapsingInfo &collapsingInfo) {
1503   unsigned counter = 0;
1504   SmallVector<ReassociationIndices> operandReassociation;
1505   auto origOpToCollapsedOpMapping =
1506       collapsingInfo.getOrigOpToCollapsedOpMapping();
1507   auto collapsedOpToOrigOpMapping =
1508       collapsingInfo.getCollapsedOpToOrigOpMapping();
1509   while (counter < indexingMap.getNumResults()) {
1510     unsigned dim =
1511         indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
1512     if (origOpToCollapsedOpMapping[dim].second == 0) {
1513       // This is the start of a collapsed dimensions of the iteration that
1514       // is gauranteed to be preserved in the indexing map. The number of folded
1515       // dims is obtained from the collapsed op to original op mapping.
1516       unsigned numFoldedDims =
1517           collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1518               .size();
1519       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1520       operandReassociation.emplace_back(range.begin(), range.end());
1521       counter += numFoldedDims;
1522     }
1523   }
1524   return operandReassociation;
1525 }
1526 
1527 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1528 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1529                                    OpOperand *opOperand,
1530                                    const CollapsingInfo &collapsingInfo,
1531                                    OpBuilder &builder) {
1532   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1533   SmallVector<ReassociationIndices> operandReassociation =
1534       getOperandReassociation(indexingMap, collapsingInfo);
1535 
1536   // If the number of entries in the reassocation for the operand is same as the
1537   // number of results of the indexing map, then nothing to do for this operand.
1538   Value operand = opOperand->get();
1539   if (operandReassociation.size() == indexingMap.getNumResults())
1540     return operand;
1541 
1542   // Insert a reshape to collapse the dimensions.
1543   auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1544       loc, operand, operandReassociation);
1545   return reshapeOp.getResult();
1546 }
1547 
1548 /// Modify the `linalg.index` operations in the original generic op, to its
1549 /// value in the collapsed operation.
1550 void generateCollapsedIndexingRegion(Location loc, Block *block,
1551                                      const CollapsingInfo &collapsingInfo,
1552                                      ValueRange loopRange,
1553                                      PatternRewriter &rewriter) {
1554   OpBuilder::InsertionGuard g(rewriter);
1555   rewriter.setInsertionPointToStart(block);
1556 
1557   // Collect all the original index ops.
1558   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1559 
1560   // For each folded dimension list resolve the original induction variable
1561   // values in terms of the folded dimension induction variable.
1562   //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1563   // can be inverted to
1564   //   i2 = i_{folded} % d2
1565   //   i1 = (i_{folded} / d2) % d1
1566   //   i0 = i_{folded} / (d1 * d2)
1567   llvm::DenseMap<unsigned, Value> indexReplacementVals;
1568   for (auto &foldedDims :
1569        enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1570     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1571     Value newIndexVal =
1572         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1573     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1574       indexReplacementVals[dim] =
1575           rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1576       newIndexVal =
1577           rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1578     }
1579     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1580   }
1581 
1582   for (auto indexOp : indexOps) {
1583     auto dim = indexOp.dim();
1584     rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1585   }
1586 }
1587 
1588 /// Implementation of fusion with reshape operation by collapsing dimensions.
1589 static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
1590     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1591     OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
1592   // Bail on trivial no-op cases.
1593   if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1594       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1595         return foldedDims.size() <= 1;
1596       }))
1597     return failure();
1598 
1599   CollapsingInfo collapsingInfo;
1600   if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1601                                        foldedIterationDims))) {
1602     return rewriter.notifyMatchFailure(
1603         genericOp, "illegal to collapse specified dimensions");
1604   }
1605 
1606   // Get the iterator types for the operand.
1607   SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
1608       genericOp.iterator_types().getValue(), collapsingInfo);
1609 
1610   // Get the indexing maps.
1611   auto indexingMaps = llvm::to_vector(
1612       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
1613         return getCollapsedOpIndexingMap(map, collapsingInfo);
1614       }));
1615 
1616   Location loc = genericOp->getLoc();
1617 
1618   // Get the input operands.
1619   auto inputOperands = llvm::to_vector(
1620       llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
1621         return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1622                                      rewriter);
1623       }));
1624 
1625   // Get the output operands and result types.
1626   SmallVector<Type> resultTypes;
1627   SmallVector<Value> outputOperands;
1628   resultTypes.reserve(genericOp.getNumOutputs());
1629   outputOperands.reserve(genericOp.getNumOutputs());
1630   for (OpOperand *output : genericOp.getOutputOperands()) {
1631     Value newOutput =
1632         getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
1633     outputOperands.push_back(newOutput);
1634     resultTypes.push_back(newOutput.getType());
1635   }
1636 
1637   // Create the generic op.
1638   auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
1639       loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1640       iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1641   Block *origOpBlock = &genericOp->getRegion(0).front();
1642   Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
1643   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1644                        collapsedOpBlock->getArguments());
1645 
1646   if (collapsedGenericOp.hasIndexSemantics()) {
1647     // Collect the loop range of the generic op.
1648     OpBuilder::InsertionGuard g(rewriter);
1649     rewriter.setInsertionPoint(collapsedGenericOp);
1650     SmallVector<Range> loopRanges =
1651         cast<LinalgOp>(genericOp.getOperation())
1652             .createLoopRanges(rewriter, genericOp.getLoc());
1653     assert(llvm::all_of(loopRanges,
1654                         [](Range range) {
1655                           return matchPattern(range.offset, m_Zero()) &&
1656                                  matchPattern(range.stride, m_One());
1657                         }) &&
1658            "expected all loop ranges to have zero start and unit stride");
1659     SmallVector<Value> loopBound = llvm::to_vector(
1660         llvm::map_range(loopRanges, [](Range range) { return range.size; }));
1661     generateCollapsedIndexingRegion(loc,
1662                                     &collapsedGenericOp->getRegion(0).front(),
1663                                     collapsingInfo, loopBound, rewriter);
1664   }
1665 
1666   // Insert expanding reshape for the result to get back the original result
1667   // type.
1668   SmallVector<Value> results;
1669   for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
1670     Value collapsedOpResult =
1671         collapsedGenericOp->getResult(originalResult.index());
1672     auto originalResultType =
1673         originalResult.value().getType().cast<ShapedType>();
1674     auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
1675     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1676       AffineMap indexingMap =
1677           genericOp.getTiedIndexingMapForResult(originalResult.value());
1678       SmallVector<ReassociationIndices> reassociation =
1679           getOperandReassociation(indexingMap, collapsingInfo);
1680       Value result = rewriter.create<tensor::ExpandShapeOp>(
1681           loc, originalResultType, collapsedOpResult, reassociation);
1682       results.push_back(result);
1683     } else {
1684       results.push_back(collapsedOpResult);
1685     }
1686   }
1687   return results;
1688 }
1689 
1690 namespace {
1691 
1692 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1693 /// contracting dimensions of the loop.
1694 class FoldWithProducerReshapeOpByCollapsing
1695     : public OpRewritePattern<GenericOp> {
1696 public:
1697   FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1698                                         ControlFusionFn foldReshapes,
1699                                         PatternBenefit benefit = 1)
1700       : OpRewritePattern<GenericOp>(context, benefit),
1701         controlFoldingReshapes(std::move(foldReshapes)) {}
1702 
1703   LogicalResult matchAndRewrite(GenericOp genericOp,
1704                                 PatternRewriter &rewriter) const override {
1705     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1706       tensor::ExpandShapeOp reshapeOp =
1707           opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
1708       if (!reshapeOp)
1709         continue;
1710 
1711       SmallVector<ReassociationIndices> collapsableIterationDims =
1712           getCollapsableIterationSpaceDims(genericOp, opOperand,
1713                                            reshapeOp.getReassociationIndices());
1714       if (collapsableIterationDims.empty() ||
1715           !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
1716         continue;
1717       }
1718 
1719       Optional<SmallVector<Value>> replacements =
1720           collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1721                                          opOperand, rewriter);
1722       if (!replacements) {
1723         return rewriter.notifyMatchFailure(
1724             genericOp, "failed to do the fusion by collapsing transformation");
1725       }
1726 
1727       rewriter.replaceOp(genericOp, replacements.getValue());
1728       return success();
1729     }
1730     return failure();
1731   }
1732 
1733 private:
1734   ControlFusionFn controlFoldingReshapes;
1735 };
1736 } // namespace
1737 
1738 //===---------------------------------------------------------------------===//
1739 // Methods and patterns to convert tensor.expand_shape -> linalg.generic
1740 // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
1741 //===---------------------------------------------------------------------===//
1742 
1743 // TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
1744 // collapsing that provides a more general functionality. This pattern is very
1745 // specific to a particular use case. The fusion by collapsing can provide the
1746 // same control to clients using the control function there.
1747 
1748 static SmallVector<ReassociationIndices>
1749 getReassociationIndices(ArrayRef<AffineMap> maps) {
1750   SmallVector<ReassociationIndices> reassociation;
1751   for (AffineMap map : maps) {
1752     ReassociationIndices indices;
1753     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
1754       unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
1755       indices.push_back(pos);
1756     }
1757     reassociation.push_back(indices);
1758   }
1759   return reassociation;
1760 }
1761 
1762 namespace {
1763 /// Pattern to move rank reducing reshape after an elementwise linalg generic
1764 /// op. This is useful to expose more fusion opportunities between named ops and
1765 /// generic ops. This can only be done if there is no broadcast or permuation
1766 /// within the dimensions we need to merge.
1767 ///
1768 /// For example,
1769 ///
1770 ///  %0 = tensor.expand_shape %A [[0, 1], [2]]
1771 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
1772 ///  %2 = linalg.generic {indexing_maps = [
1773 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1774 ///    affine_map<(d0, d1, d2) -> (d2)>,
1775 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
1776 ///    ["parallel", "parallel", "parallel"]} {
1777 ///  } -> tensor<112x112x16xf32>
1778 ///
1779 ///  into
1780 ///
1781 ///  %2 = linalg.generic {indexing_maps = [
1782 ///    affine_map<(d0, d1) -> (d0, d1)>,
1783 ///    affine_map<(d0, d1) -> (d1)>,
1784 ///    affine_map<(d0, d1) -> (d0, d1)>],
1785 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
1786 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
1787 ///  } -> tensor<12544x16xf32>
1788 ///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
1789 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
1790 struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
1791   using OpRewritePattern<GenericOp>::OpRewritePattern;
1792 
1793   LogicalResult matchAndRewrite(GenericOp genericOp,
1794                                 PatternRewriter &rewriter) const override {
1795     // Only apply to elementwise linalg on tensor.
1796     if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
1797         genericOp.getNumParallelLoops() != genericOp.getNumLoops())
1798       return failure();
1799     // Only support identity output maps. It could be extended to permuations if
1800     // needed.
1801     if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
1802           return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
1803         }))
1804       return failure();
1805     int64_t destRank = genericOp.getNumParallelLoops();
1806     SmallVector<Value> newOperands = genericOp.getInputOperands();
1807     tensor::ExpandShapeOp reshapeFound;
1808     // 1. Look for tensor_expand_shape operands and figure out save the
1809     // dimensions merged.
1810     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
1811     for (const auto &en : llvm::enumerate(inputOperands)) {
1812       auto reshapeOp =
1813           en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
1814       if (!reshapeOp)
1815         continue;
1816       // TODO: We could support non-identity map as long as the merged
1817       // dimensions are still contiguous.
1818       if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
1819         continue;
1820       if (reshapeFound) {
1821         // Only support a second reshape op if it has the same reassociate maps.
1822         if (reshapeFound.getReassociationMaps() ==
1823             reshapeOp.getReassociationMaps())
1824           newOperands[en.index()] = reshapeOp.src();
1825         continue;
1826       }
1827       reshapeFound = reshapeOp;
1828       newOperands[en.index()] = reshapeOp.src();
1829     }
1830     if (!reshapeFound)
1831       return failure();
1832 
1833     // Calculate the reassociation indices and rassociated reverse map.
1834     SmallVector<ReassociationIndices> reassociation =
1835         getReassociationIndices(reshapeFound.getReassociationMaps());
1836     SmallVector<unsigned> remap(destRank);
1837     for (auto &indices : llvm::enumerate(reassociation)) {
1838       for (int64_t index : indices.value()) {
1839         remap[index] = indices.index();
1840       }
1841     }
1842     // 2. Verify that we can merge the dimensions in the linalg and that we
1843     // don't need to create new reshapes operands. Inserting new reshape
1844     // operands would defeat the purpose of the transformation.
1845     for (const auto &en : llvm::enumerate(inputOperands)) {
1846       if (en.value()->get() == newOperands[en.index()]) {
1847         AffineMap map = genericOp.getTiedIndexingMap(en.value());
1848         for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
1849           if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
1850             return failure();
1851         }
1852       }
1853     }
1854 
1855     // 3. Calculate the affine map remapping and the reassociation to apply to
1856     // output tensors.
1857     SmallVector<AffineMap> newMaps;
1858     unsigned newRank = reassociation.size();
1859     for (auto map : genericOp.getIndexingMaps()) {
1860       SmallVector<AffineExpr> newExprs;
1861       for (auto expr : map.getResults()) {
1862         unsigned position = expr.template cast<AffineDimExpr>().getPosition();
1863         // Skip dimension merged except for the last of the group.
1864         if (reassociation[remap[position]].back() == position) {
1865           newExprs.push_back(
1866               getAffineDimExpr(remap[position], genericOp.getContext()));
1867         }
1868       }
1869       newMaps.push_back(
1870           AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
1871     }
1872 
1873     // 4. Reshape the output tensors.
1874     SmallVector<Value> newOutputs;
1875     SmallVector<Type> newOutputTypes;
1876     for (auto output : genericOp.outputs()) {
1877       auto newOutputType = RankedTensorType::get(
1878           reshapeFound.getSrcType().getShape(),
1879           output.getType().template cast<RankedTensorType>().getElementType());
1880       Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
1881           genericOp->getLoc(), newOutputType, output, reassociation);
1882       newOutputTypes.push_back(newOutputType);
1883       newOutputs.push_back(newOutput);
1884     }
1885     // 5. Create a new generic op with lowerer rank.
1886     SmallVector<StringRef> iteratorTypes(newRank,
1887                                          getParallelIteratorTypeName());
1888     auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
1889                                             newOperands, newOutputs, newMaps,
1890                                             iteratorTypes);
1891     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
1892                                 newOp.region().begin());
1893     // 6. Reshape the so that the type matches the uses.
1894     SmallVector<Value> newResults;
1895     for (const auto &result : llvm::enumerate(newOp->getResults())) {
1896       newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
1897           genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
1898           result.value(), reassociation));
1899     }
1900     rewriter.replaceOp(genericOp, newResults);
1901     return success();
1902   }
1903 };
1904 } // namespace
1905 
1906 //===---------------------------------------------------------------------===//
1907 // Methods and patterns that fuse constants with linalg.generic operations.
1908 //===---------------------------------------------------------------------===//
1909 
1910 namespace {
1911 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1912 /// handle cases where the constant is not single-valued.
1913 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1914 public:
1915   FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1916       : OpRewritePattern<GenericOp>(context, benefit) {}
1917 
1918   LogicalResult matchAndRewrite(GenericOp genericOp,
1919                                 PatternRewriter &rewriter) const override {
1920     if (!genericOp.hasTensorSemantics())
1921       return failure();
1922     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1923       Operation *def = opOperand->get().getDefiningOp();
1924       Attribute constantAttr;
1925       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1926         {
1927           DenseElementsAttr splatAttr;
1928           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1929               splatAttr.isSplat() &&
1930               splatAttr.getType().getElementType().isIntOrFloat()) {
1931             constantAttr = splatAttr.getSplatValue<Attribute>();
1932             return true;
1933           }
1934         }
1935         {
1936           IntegerAttr intAttr;
1937           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1938             constantAttr = intAttr;
1939             return true;
1940           }
1941         }
1942         {
1943           FloatAttr floatAttr;
1944           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1945             constantAttr = floatAttr;
1946             return true;
1947           }
1948         }
1949         return false;
1950       };
1951 
1952       auto resultValue = opOperand->get().dyn_cast<OpResult>();
1953       if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1954         continue;
1955 
1956       // The operands and the indexing_maps of the fused operation the same as
1957       // the operands and indexing_maps of the generic operations with the
1958       // values at the constant index dropped.
1959       SmallVector<AffineMap> fusedIndexMaps;
1960       SmallVector<Value> fusedOperands;
1961       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1962       fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1963       fusedOperands.reserve(genericOp.getNumInputs());
1964       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1965       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1966         if (inputOperand == opOperand)
1967           continue;
1968         Value inputValue = inputOperand->get();
1969         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1970         fusedOperands.push_back(inputValue);
1971         fusedLocs.push_back(inputValue.getLoc());
1972       }
1973       for (OpOperand *outputOperand : genericOp.getOutputOperands())
1974         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1975 
1976       // Check if the operation shapes to loops map is computable.
1977       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1978         return rewriter.notifyMatchFailure(
1979             genericOp, "fused op loop bound computation failed");
1980       }
1981 
1982       // Create a constant scalar value from the splat constant.
1983       Value scalarConstant = rewriter.create<arith::ConstantOp>(
1984           def->getLoc(), constantAttr, constantAttr.getType());
1985 
1986       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1987       auto fusedOp = rewriter.create<GenericOp>(
1988           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1989           /*inputs=*/fusedOperands,
1990           /*outputs=*/outputOperands,
1991           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1992           genericOp.iterator_types(),
1993           /*doc=*/nullptr,
1994           /*library_call=*/nullptr);
1995 
1996       // Map the block argument corresponding to the replaced argument with the
1997       // scalar constant.
1998       Region &region = genericOp->getRegion(0);
1999       Block &entryBlock = *region.begin();
2000       BlockAndValueMapping mapping;
2001       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2002                   scalarConstant);
2003       Region &fusedRegion = fusedOp->getRegion(0);
2004       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2005                                  mapping);
2006       rewriter.replaceOp(genericOp, fusedOp->getResults());
2007       return success();
2008     }
2009     return failure();
2010   }
2011 };
2012 
2013 } // namespace
2014 
2015 //===---------------------------------------------------------------------===//
2016 // Miscellaneous patterns that help fusion.
2017 //===---------------------------------------------------------------------===//
2018 
2019 namespace {
2020 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
2021 /// the value of the `outs` operand is not used within the op.  This is only
2022 /// implemented for `linalg.generic` operations for now, but should hold for all
2023 /// linalg structured ops.
2024 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2025   using OpRewritePattern<GenericOp>::OpRewritePattern;
2026 
2027   LogicalResult matchAndRewrite(GenericOp op,
2028                                 PatternRewriter &rewriter) const override {
2029     rewriter.startRootUpdate(op);
2030     bool modifiedOutput = false;
2031     Location loc = op.getLoc();
2032     for (OpOperand *opOperand : op.getOutputOperands()) {
2033       if (!op.payloadUsesValueFromOperand(opOperand)) {
2034         Value operandVal = opOperand->get();
2035         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
2036         if (!operandType)
2037           continue;
2038 
2039         // If outs is sparse, leave it to the sparse compiler.
2040         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2041           continue;
2042 
2043         // If outs is already an `init_tensor` operation, nothing to do.
2044         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
2045         if (definingOp)
2046           continue;
2047         modifiedOutput = true;
2048         SmallVector<Value> dynamicDims;
2049         for (const auto &dim : llvm::enumerate(operandType.getShape())) {
2050           if (dim.value() != ShapedType::kDynamicSize)
2051             continue;
2052           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
2053               loc, operandVal, dim.index()));
2054         }
2055         Value initTensor = rewriter.create<InitTensorOp>(
2056             loc, dynamicDims, operandType.getShape(),
2057             operandType.getElementType());
2058         op->setOperand(opOperand->getOperandNumber(), initTensor);
2059       }
2060     }
2061     if (!modifiedOutput) {
2062       rewriter.cancelRootUpdate(op);
2063       return failure();
2064     }
2065     rewriter.finalizeRootUpdate(op);
2066     return success();
2067   }
2068 };
2069 
2070 /// Fold linalg.fill into linalg.generic
2071 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2072   using OpRewritePattern<GenericOp>::OpRewritePattern;
2073 
2074   LogicalResult matchAndRewrite(GenericOp genericOp,
2075                                 PatternRewriter &rewriter) const override {
2076     if (!genericOp.hasTensorSemantics())
2077       return failure();
2078     bool fillFound = false;
2079     Block &payload = genericOp.region().front();
2080     for (OpOperand *opOperand : genericOp.getInputOperands()) {
2081       if (!genericOp.payloadUsesValueFromOperand(opOperand))
2082         continue;
2083       FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2084       if (!fillOp)
2085         continue;
2086       fillFound = true;
2087       payload.getArgument(opOperand->getOperandNumber())
2088           .replaceAllUsesWith(fillOp.value());
2089     }
2090     return success(fillFound);
2091   }
2092 };
2093 } // namespace
2094 //===---------------------------------------------------------------------===//
2095 // Methods that add patterns described in this file to a pattern list.
2096 //===---------------------------------------------------------------------===//
2097 
2098 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
2099     RewritePatternSet &patterns) {
2100   patterns.add<
2101       FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
2102       FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
2103       FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
2104       patterns.getContext());
2105 }
2106 
2107 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
2108     RewritePatternSet &patterns) {
2109   patterns
2110       .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
2111            FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
2112            FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
2113           patterns.getContext());
2114 }
2115 
2116 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2117     RewritePatternSet &patterns,
2118     const ControlFusionFn &controlFoldingReshapes) {
2119   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2120                                                     controlFoldingReshapes);
2121   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2122                                                      controlFoldingReshapes);
2123 }
2124 
2125 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2126     RewritePatternSet &patterns,
2127     const ControlFusionFn &controlFoldingReshapes) {
2128   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2129                                                       controlFoldingReshapes);
2130 }
2131 
2132 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2133     RewritePatternSet &patterns,
2134     const ControlFusionFn &controlElementwiseOpsFusion) {
2135   auto *context = patterns.getContext();
2136   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2137   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2138                RemoveOutsDependency>(context);
2139 }
2140 
2141 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
2142   auto *context = patterns.getContext();
2143   patterns.add<PushExpandingReshape>(context);
2144 }
2145 
2146 //===---------------------------------------------------------------------===//
2147 // Passes
2148 //===---------------------------------------------------------------------===//
2149 
2150 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
2151                                       OpOperand &consumer) {
2152   if (auto producerCollapseOp =
2153           dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
2154     return !isUnitDimExpansionOnly(producerCollapseOp);
2155   }
2156   if (auto consumerExpandOp =
2157           dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
2158     return !isUnitDimExpansionOnly(consumerExpandOp);
2159   }
2160   return true;
2161 }
2162 
2163 namespace {
2164 
2165 /// Pass that fuses generic ops on tensors. Used only for testing.
2166 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2167 // patterns added here heavily depends on the cost function used. Having an
2168 // opinionated pass of this form is not recommended. Deprecate this pass in
2169 // favor of test passes that check the functionality of each of the patterns
2170 // added here individually.
2171 struct LinalgElementwiseOpFusionPass
2172     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
2173   void runOnOperation() override {
2174     Operation *op = getOperation();
2175     MLIRContext *context = op->getContext();
2176     RewritePatternSet patterns(context);
2177 
2178     // Add folding with reshape by expansion patterns.
2179     ControlFusionFn defaultControlFn = [](const OpResult &producer,
2180                                           const OpOperand &consumer) {
2181       return producer.hasOneUse();
2182     };
2183 
2184     // Add elementwise op fusion patterns.
2185     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2186 
2187     populateFoldReshapeOpsByExpansionPatterns(
2188         patterns,
2189         allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape);
2190 
2191     // Add the sparse tensor rewriting patterns.
2192     populateSparseTensorRewriting(patterns);
2193 
2194     // General canonicalization patterns.
2195     AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2196     GenericOp::getCanonicalizationPatterns(patterns, context);
2197     tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2198     tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2199     context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2200         patterns);
2201 
2202     // Add constant folding patterns.
2203     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2204 
2205     // Use TopDownTraversal for compile time reasons
2206     GreedyRewriteConfig grc;
2207     grc.useTopDownTraversal = true;
2208     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
2209                                        grc);
2210   }
2211 };
2212 
2213 /// Pass to test folding of reshape ops with generic ops by linearization.
2214 struct FoldReshapeOpsByLinearizationPass
2215     : public LinalgFoldReshapeOpsByLinearizationBase<
2216           FoldReshapeOpsByLinearizationPass> {
2217   void runOnOperation() override {
2218     Operation *op = getOperation();
2219     RewritePatternSet patterns(op->getContext());
2220     populateFoldReshapeOpsByLinearizationPatterns(patterns);
2221     if (allowFoldingUnitDimReshapes) {
2222       populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
2223     }
2224     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
2225   }
2226 };
2227 
2228 } // namespace
2229 
2230 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
2231   return std::make_unique<LinalgElementwiseOpFusionPass>();
2232 }
2233 
2234 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
2235   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
2236 }
2237