1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LLVM.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 /// Conditions for elementwise fusion of generic operations.
30 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
31                                      OpOperand *consumerOpOperand) {
32   // Producer and consumer must have tensor semantics.
33   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
34     return false;
35 
36   // Verify that
37   // - the producer has all "parallel" iterator type.
38   if (producer.getNumParallelLoops() != producer.getNumLoops())
39     return false;
40 
41   // Only allow fusing the producer of an input operand for now.
42   // TODO: allow fusing the producer of an output operand.
43   if (!consumer.isInputTensor(consumerOpOperand))
44     return false;
45 
46   // Get the consumer index map. The number of results of the consumer index
47   // map must match the number of loops of the producer.
48   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
49   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
50     return false;
51 
52   // Currently support only operations with single result.
53   if (producer.getNumOutputs() != 1)
54     return false;
55 
56   // Finally the index_map for the result must be invertible. For now just
57   // verify it is a permutation.
58   AffineMap producerResultIndexMap =
59       producer.getTiedIndexingMap(producer.getOutputOperand(0));
60   return producerResultIndexMap.isPermutation();
61 }
62 
63 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
64 /// the `producer` to use in the fused operation given the indexing map of the
65 /// result of the producer in the consumer.
66 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
67     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
68     AffineMap fusedConsumerArgIndexMap) {
69   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
70   // from consumer loop -> consumer arg tensor index/producer result tensor
71   // index. The fused loop is same as the consumer loop. For each producer arg
72   // the indexing map to be computed is a map from consumer loop -> producer
73   // arg tensor index.
74   // producerResultIndexMap is a map from producer loop -> tensor index.
75   // Compute the inverse to get map from tensor index -> producer loop.
76   // The inverse is a map from producer result tensor index -> producer loop.
77   AffineMap invProducerResultIndexMap =
78       inversePermutation(producerResultIndexMap);
79   assert(invProducerResultIndexMap &&
80          "expected producer result indexig map to be invertible");
81 
82   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
83   // argMap is a map from producer loop -> producer arg tensor index.
84   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
85 
86   // Compose argMap with invProducerResultIndexMap to get a map from
87   // producer result tensor index -> producer arg tensor index.
88   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
89 
90   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
91   // consumer loop/ fused loop -> producer arg tensor index.
92   return t1.compose(fusedConsumerArgIndexMap);
93 }
94 
95 /// Generate the region of the fused tensor operation. The region of the fused
96 /// op must be empty.
97 static void
98 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
99                                  AffineMap consumerToProducerLoopsMap,
100                                  OpOperand *consumerOpOperand,
101                                  unsigned nloops) {
102   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
103   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
104   // Build the region of the fused op.
105   Block &producerBlock = producer->getRegion(0).front();
106   Block &consumerBlock = consumer->getRegion(0).front();
107   Block *fusedBlock = new Block();
108   fusedOp.region().push_back(fusedBlock);
109   BlockAndValueMapping mapper;
110   OpBuilder::InsertionGuard guard(rewriter);
111   rewriter.setInsertionPointToStart(fusedBlock);
112 
113   // 2. Add an index operation for every fused loop dimension and use the
114   // `consumerToProducerLoopsMap` to map the producer indices.
115   if (producer.hasIndexSemantics()) {
116     // Add an index operation for every fused loop dimension.
117     unsigned numFusedOpLoops =
118         std::max(producer.getNumLoops(), consumer.getNumLoops());
119     SmallVector<Value> fusedIndices;
120     fusedIndices.reserve(numFusedOpLoops);
121     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
122                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
123                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
124                     });
125     for (IndexOp indexOp :
126          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
127       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
128           producer.getLoc(),
129           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
130       mapper.map(indexOp.getResult(), newIndex);
131     }
132   }
133   // TODO: allow fusing the producer of an output operand.
134   assert(consumer.isInputTensor(consumerOpOperand) &&
135          "expected producer of input operand");
136   // 3. Consumer input operands up to consumerIdx (exclusive).
137   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
138            consumerOpOperand->getOperandNumber())) // input assumption.
139     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
140 
141   // Replacing consumerIdx requires getting the cloned, yielded, value from
142   // the (cloned) producer block. This happens in step 9.
143 
144   // 4. Splice in producer's input operands.
145   for (BlockArgument bbArg :
146        producerBlock.getArguments().take_front(producer.getNumInputs()))
147     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
148 
149   // 4.b. Producer output operand/map that is fused needs to be mapped to the
150   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
151   assert(producer->getNumResults() == 1 && "expected single result producer");
152   if (producer.isInitTensor(producer.getOutputOperand(0))) {
153     BlockArgument bbArg = producerBlock.getArguments()
154                               .drop_front(producer.getNumInputs())
155                               // TODO: bbArg index of
156                               .front();
157     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
158   }
159   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
160   for (BlockArgument bbArg :
161        consumerBlock.getArguments()
162            .take_front(consumer.getNumInputs())
163            .drop_front(consumerOpOperand->getOperandNumber() + 1))
164     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
165   // 6. All of consumer's output operands.
166   for (BlockArgument bbArg :
167        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
168     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
169   // 7. All of producer's output operands except the one fused.
170   // TODO: allow fusion of multi-result producers.
171   assert(producer->getNumResults() == 1 && "expected single result producer");
172 
173   // 8. Clone all producer operations except for the yield and index operations
174   // to the fused operation.
175   for (auto &op : producerBlock.without_terminator()) {
176     if (!isa<IndexOp>(op))
177       rewriter.clone(op, mapper);
178   }
179   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
180   // forward the yield operand.
181   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
182   // TODO: allow fusion of multi-result producers.
183   assert(producer->getNumResults() == 1 && "expected single result producer");
184   unsigned producerResultNumber = 0;
185   Value replacement =
186       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
187   // Sanity checks, if replacement is not already in the mapper then it must be
188   // produced outside.
189   if (replacement == yieldOp.getOperand(producerResultNumber)) {
190     if (auto bb = replacement.dyn_cast<BlockArgument>())
191       assert(bb.getOwner() != &producerBlock &&
192              "yielded block argument must have been mapped");
193     else
194       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
195              "yielded value must have been mapped");
196   }
197   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
198              replacement);
199   // 10. Clone operations from the consumer to the fused op.
200   for (auto &op : consumerBlock.getOperations())
201     rewriter.clone(op, mapper);
202 
203   // Sanity checks.
204   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
205          "Ill-formed GenericOp region");
206 }
207 
208 static Optional<SmallVector<Value>>
209 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
210                        const ControlElementwiseOpsFusionFn &controlFn,
211                        PatternRewriter &rewriter) {
212   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
213   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
214       !controlFn(producer->getResult(0), *consumerOpOperand))
215     return llvm::None;
216 
217   // TODO: allow fusing the producer of an output operand.
218   assert(consumer.isInputTensor(consumerOpOperand) &&
219          "expected producer of input operand");
220 
221   // Compute the fused operands list and indexing maps.
222   SmallVector<Value> fusedOperands;
223   SmallVector<AffineMap> fusedIndexMaps;
224   fusedOperands.reserve(producer->getNumOperands() +
225                         consumer->getNumOperands());
226   fusedIndexMaps.reserve(producer->getNumOperands() +
227                          consumer->getNumOperands());
228   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
229   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
230   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
231   SmallVector<OpOperand *>::iterator it =
232       llvm::find(consumerInputs, consumerOpOperand);
233   assert(it != consumerInputs.end() && "expected to find the consumer operand");
234   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
235     fusedOperands.push_back(opOperand->get());
236     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
237   }
238   // 4. Splice in producer's input operands/maps.
239   assert(producer->getNumResults() == 1 && "expected single result producer");
240   AffineMap producerResultIndexMap =
241       producer.getTiedIndexingMap(producer.getOutputOperand(0));
242   for (OpOperand *opOperand : producer.getInputOperands()) {
243     fusedOperands.push_back(opOperand->get());
244     // Compute indexing maps for the producer args in the fused operation.
245     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
246         opOperand, producerResultIndexMap,
247         consumer.getTiedIndexingMap(consumerOpOperand));
248     fusedIndexMaps.push_back(map);
249   }
250   // 4.b. Producer output operand/map that is fused needs to be passed if it is
251   // an "initTensor" (i.e. its value is actually read).
252   assert(producer->getNumResults() == 1 && "expected single result producer");
253   if (producer.isInitTensor(producer.getOutputOperand(0))) {
254     fusedOperands.push_back(producer.getOutputOperand(0)->get());
255     // Compute indexing maps for the producer args in the fused operation.
256     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
257         producer.getOutputOperand(0), producerResultIndexMap,
258         consumer.getTiedIndexingMap(consumerOpOperand));
259     fusedIndexMaps.push_back(map);
260   }
261   // 5. Remaining consumer's input operands/maps (drop past index
262   // `consumerIdx`).
263   for (OpOperand *opOperand :
264        llvm::make_range(std::next(it), consumerInputs.end())) {
265     fusedOperands.push_back(opOperand->get());
266     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
267   }
268   // 6. All of consumer's output operands (skip operands: added by the builder).
269   for (OpOperand *opOperand : consumer.getOutputOperands())
270     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
271   // 7. All of producer's output operands/maps except the one fused.
272   // TODO: allow fusion of multi-result producers.
273   assert(producer->getNumResults() == 1 && "expected single result producer");
274 
275   // Generate the fused op.
276   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
277   auto fusedOp = rewriter.create<GenericOp>(
278       consumer.getLoc(), consumer->getResultTypes(),
279       /*inputs=*/fusedOperands,
280       // TODO: handle outputs.
281       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
282       consumer.iterator_types(),
283       /*doc=*/nullptr,
284       /*library_call=*/nullptr);
285 
286   // Construct an AffineMap from consumer loops to producer loops.
287   // consumer loop -> tensor index
288   AffineMap consumerResultIndexMap =
289       consumer.getTiedIndexingMap(consumerOpOperand);
290   // tensor index -> producer loop
291   AffineMap invProducerResultIndexMap =
292       inversePermutation(producerResultIndexMap);
293   assert(invProducerResultIndexMap &&
294          "expected producer result indexig map to be invertible");
295   // consumer loop -> producer loop
296   AffineMap consumerToProducerLoopsMap =
297       invProducerResultIndexMap.compose(consumerResultIndexMap);
298 
299   generateFusedElementwiseOpRegion(rewriter, fusedOp,
300                                    consumerToProducerLoopsMap,
301                                    consumerOpOperand, consumer.getNumLoops());
302   return SmallVector<Value>(fusedOp->getResults());
303 }
304 
305 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
306 /// provided, given the shape of the source tensor that corresponds to the
307 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
308 /// are "row-major" ordered logically.
309 ///
310 /// For example:
311 ///
312 /// %0 = op ... : tensor<?x?x4x5xf32>
313 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
314 ///
315 /// and reshape:
316 /// %1 = linalg.tensor_collapse_shape %0 [[0], [0, 1, 2]] :
317 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
318 ///
319 /// would be rewritten into:
320 /// %0 = op ... : tensor<?x?x4x5xf32>
321 /// with output index_map
322 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
323 template <typename TensorReshapeOp>
324 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
325                                         TensorReshapeOp reshapeOp) {
326   constexpr bool isExpanding =
327       std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
328   ArrayRef<int64_t> sourceShape =
329       (isExpanding ? reshapeOp.getResultType().getShape()
330                    : reshapeOp.getSrcType().getShape());
331   SmallVector<AffineExpr> resultExprs;
332   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
333   MLIRContext *context = sourceMap.getContext();
334 
335   // Compute the result exprs based on the reassociation maps.
336   for (auto &indices : reshapeOp.getReassociationIndices()) {
337     // Assume that they are in-order and contiguous (already checked in
338     // verifier).
339     assert(!indices.empty());
340     SmallVector<int64_t> sizes;
341     SmallVector<AffineExpr> dimExprs;
342     for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
343                              sourceExprs.slice(indices[0], indices.size()))) {
344       if (std::get<0>(en) == 1)
345         continue;
346       sizes.push_back(std::get<0>(en));
347       dimExprs.push_back(std::get<1>(en));
348     }
349     AffineExpr linearizedExpr =
350         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
351     resultExprs.push_back(linearizedExpr);
352   }
353   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
354                         resultExprs, context);
355 }
356 
357 // TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a
358 // producer). Fusing when operand has higher rank will require use of mods and
359 // divs in the indexing maps of the fused op which would make it non-invertible.
360 static bool isTensorReshapeOpFoldableByLinearization(
361     TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
362   if (!asProducer)
363     return false;
364   return useIndexMap.isPermutation();
365 }
366 
367 // TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a
368 // consumer).
369 static bool isTensorReshapeOpFoldableByLinearization(
370     TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
371   if (asProducer)
372     return false;
373   return useIndexMap.isPermutation();
374 }
375 
376 /// Check if the reshape operation is only expansion into/collapsing of
377 /// unit-dimension.
378 template <typename TensorReshapeOp>
379 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
380   constexpr bool isExpanding =
381       std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
382   ArrayRef<int64_t> expandedShape =
383       (isExpanding ? reshapeOp.getResultType().getShape()
384                    : reshapeOp.getSrcType().getShape());
385   for (auto &indices : reshapeOp.getReassociationIndices()) {
386     unsigned numUnitDims = 0;
387     for (int64_t position : indices)
388       if (expandedShape[position] == 1)
389         numUnitDims++;
390     if (numUnitDims != indices.size() - 1)
391       return false;
392   }
393   return true;
394 }
395 
396 /// Conditions for folding a generic operation with a reshape op by expanding
397 /// the iteration space dimensionality for tensor operations. These are
398 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
399 /// following fusion pattern.
400 ///
401 ///  Consider
402 ///
403 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
404 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
405 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
406 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
407 ///  %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
408 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
409 ///
410 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
411 ///  is increased to match the result (operand) of the tensor_expand_shape.
412 ///  The indexing_map of the fused tensor in the `genericOp` and the
413 ///  reassociation map helps compute the indexing maps of the modified op.
414 ///  For the above example, based on the reassociation map it
415 ///  can be concluded that
416 ///
417 ///  - The loop used to access the first dimension of the fused tensor is split
418 ///    into two.
419 ///  - The loop used to access the second dimension of the fused tensor is kept
420 ///    as is.
421 ///  - The loop used to access the third dimension of the fused tensor is split
422 ///    into three.
423 ///
424 ///  i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
425 ///  op, then
426 ///
427 ///   d0 -> e0, e1
428 ///   d1 -> e2, e3, e4
429 ///   d2 -> e5
430 ///
431 ///  substituting this, the generic op can be rewritten as
432 ///
433 ///  %d = linalg.generic ins(%0, %1 : )
434 ///        indexing_maps =
435 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
436 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
437 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
438 ///
439 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
440 ///  to make it consistent
441 ///
442 ///  %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]]
443 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
444 ///  %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]]
445 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
446 ///
447 ///  The added reshapes are again expanding patterns, so they will get fused
448 ///  with its producers if possible.
449 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
450                                                OpOperand *fusableOpOperand) {
451   // Is fusable only if:
452   // - All the indexing maps for operands and results are projected
453   //   permutations.
454   // - The fused tensor is not a scalar.
455   // - All the loops are parallel loops.
456   return genericOp.hasTensorSemantics() &&
457          llvm::all_of(genericOp.indexing_maps().getValue(),
458                       [](Attribute attr) {
459                         return attr.cast<AffineMapAttr>()
460                             .getValue()
461                             .isProjectedPermutation();
462                       }) &&
463          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
464          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
465            return attr.cast<StringAttr>().getValue() ==
466                   getParallelIteratorTypeName();
467          });
468 }
469 
470 namespace {
471 /// Information needed to expand a generic operation to fold the reshape with
472 /// it.
473 class ExpansionInfo {
474 public:
475   // Computes the mapping from original dimensions of the op to the dimensions
476   // of the expanded op given the `indexingMap` of the fused operand/result of
477   // the generic op, the `reassocationMaps` of the reshape op and the shape of
478   // the expanded op.
479   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
480                         ArrayRef<AffineMap> reassociationMaps,
481                         ArrayRef<int64_t> expandedShape,
482                         PatternRewriter &rewriter);
483   unsigned getOrigOpNumDims() const { return reassociation.size(); }
484   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
485   ReassociationIndicesRef getExpandedDims(unsigned i) const {
486     return reassociation[i];
487   }
488   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
489     return expandedShapeMap[i];
490   }
491 
492 private:
493   /// Reassociation from the dimensions in the original operation to the
494   /// dimension of the expanded operation.
495   SmallVector<ReassociationIndices> reassociation;
496   /// Mapping from extent of loops in the original operation, to the extent of
497   /// loops in the expanded operation.
498   SmallVector<SmallVector<int64_t>> expandedShapeMap;
499   unsigned expandedOpNumDims;
500 };
501 } // namespace
502 
503 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
504                                      OpOperand *fusableOpOperand,
505                                      ArrayRef<AffineMap> reassociationMaps,
506                                      ArrayRef<int64_t> expandedShape,
507                                      PatternRewriter &rewriter) {
508   if (reassociationMaps.empty())
509     return failure();
510   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
511 
512   Optional<SmallVector<int64_t, 4>> originalLoopRange =
513       linalgOp.getStaticLoopRanges();
514   if (!originalLoopRange)
515     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
516 
517   reassociation.clear();
518   expandedShapeMap.clear();
519   // Compute the number of dimension in the expanded op that correspond to each
520   // dimension of the original op.
521   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
522   expandedShapeMap.resize(fusedIndexMap.getNumDims());
523   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
524     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
525     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
526     numExpandedDims[pos] = foldedDims.getNumResults();
527     ArrayRef<int64_t> shape =
528         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
529     expandedShapeMap[pos].assign(shape.begin(), shape.end());
530   }
531   // The remaining dimensions remain the same.
532   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
533     if (expandedShapeMap[i].empty())
534       expandedShapeMap[i] = {(*originalLoopRange)[i]};
535 
536   // Compute reassociation map from the original op to the expanded op.
537   unsigned sum = 0;
538   reassociation.reserve(fusedIndexMap.getNumDims());
539   for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
540     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
541     reassociation.emplace_back(seq.begin(), seq.end());
542     sum += numFoldedDim.value();
543   }
544   expandedOpNumDims = sum;
545   return success();
546 }
547 
548 /// Epanding the body of a linalg operation requires adaptations of the accessed
549 /// loop indices. Specifically, access of indices in the original operation need
550 /// to be replaced with linearizations of indices in the expanded op. That
551 /// requires the shape of the expanded dimensions to be static (at least all but
552 /// the most significant). For now check that these are all statically sized.
553 /// Note that this could be extended to handle dynamic case, but the
554 /// implementation below uses `affine.apply` which seems to have issues when the
555 /// shapes are not static.
556 LogicalResult isGenericOpExpandable(GenericOp genericOp,
557                                     const ExpansionInfo &expansionInfo,
558                                     PatternRewriter &rewriter) {
559   if (!genericOp.hasIndexSemantics())
560     return success();
561   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
562     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
563     if (expandedShape.size() == 1)
564       continue;
565     for (int64_t shape : expandedShape.drop_front()) {
566       if (ShapedType::isDynamic(shape)) {
567         return rewriter.notifyMatchFailure(
568             genericOp, "cannot expand due to index semantics and dynamic dims");
569       }
570     }
571   }
572   return success();
573 }
574 
575 /// Return the indexing map to use in the expanded op for a given the
576 /// `indexingMap` of the original operation.
577 static AffineMap
578 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
579                            const ExpansionInfo &expansionInfo) {
580   SmallVector<AffineExpr> newExprs;
581   for (AffineExpr expr : indexingMap.getResults()) {
582     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
583     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
584         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
585           return builder.getAffineDimExpr(static_cast<unsigned>(v));
586         }));
587     newExprs.append(expandedExprs.begin(), expandedExprs.end());
588   }
589   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
590                         indexingMap.getNumSymbols(), newExprs,
591                         builder.getContext());
592 }
593 
594 /// Return the type of the operand/result to use in the expanded op given the
595 /// type in the original op.
596 static RankedTensorType getExpandedType(RankedTensorType originalType,
597                                         AffineMap indexingMap,
598                                         const ExpansionInfo &expansionInfo) {
599   SmallVector<int64_t> expandedShape;
600   for (AffineExpr expr : indexingMap.getResults()) {
601     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
602     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
603     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
604   }
605   return RankedTensorType::get(expandedShape, originalType.getElementType());
606 }
607 
608 /// Returns the reassociation maps to use in the `linalg.tensor_expand_shape`
609 /// operation to convert the operands of the original operation to operands of
610 /// the expanded operation. The same method is used to compute the
611 /// `linalg.tensor_collapse_shape` used to collapse the result of the expanded
612 /// op to get the value that can replace all uses of the results of the original
613 /// op.
614 static SmallVector<ReassociationIndices>
615 getReassociationForExpansion(AffineMap indexingMap,
616                              const ExpansionInfo &expansionInfo) {
617   SmallVector<ReassociationIndices> reassociation;
618   unsigned numReshapeDims = 0;
619   for (AffineExpr expr : indexingMap.getResults()) {
620     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
621     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
622     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
623         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
624     reassociation.emplace_back(std::move(indices));
625     numReshapeDims += numExpandedDims;
626   }
627   return reassociation;
628 }
629 
630 /// Update the body of an expanded linalg operation having index semantics. The
631 /// indices of the original operation need to be recovered by linearizing the
632 /// indices of the correspoding dimensions of the expanded operation. For now it
633 /// is assumed that the shapes of the expanded operation needed for
634 /// linearization are static.
635 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
636                                           Location loc, Region &fusedRegion,
637                                           const ExpansionInfo &expansionInfo) {
638   // Replace the original indices by the linearization of the expanded indices.
639   for (IndexOp indexOp :
640        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
641     ArrayRef<int64_t> expandedDims =
642         expansionInfo.getExpandedDims(indexOp.dim());
643     assert(!expandedDims.empty() && "expected valid expansion info");
644 
645     // Skip index operations that are not affected by the expansion.
646     if (expandedDims.size() == 1 &&
647         expandedDims.front() == (int64_t)indexOp.dim())
648       continue;
649 
650     // Linearize the expanded indices of the original index dimension.
651     OpBuilder::InsertionGuard guard(rewriter);
652     rewriter.setInsertionPointAfter(indexOp);
653     ArrayRef<int64_t> expandedDimsShape =
654         expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
655     SmallVector<Value> expandedIndices;
656     expandedIndices.reserve(expandedDims.size() - 1);
657     llvm::transform(
658         expandedDims.drop_front(), std::back_inserter(expandedIndices),
659         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
660     Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
661     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
662       assert(!ShapedType::isDynamic(std::get<0>(it)));
663       AffineExpr idx, acc;
664       bindDims(rewriter.getContext(), idx, acc);
665       newIndex = rewriter.create<AffineApplyOp>(
666           indexOp.getLoc(), idx + acc * std::get<0>(it),
667           ValueRange{std::get<1>(it), newIndex});
668     }
669     rewriter.replaceOp(indexOp, newIndex);
670   }
671 }
672 
673 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
674 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
675 /// that those conditions have been satisfied.
676 static Optional<SmallVector<Value>>
677 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
678                            OpOperand *fusableOpOperand,
679                            PatternRewriter &rewriter) {
680   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
681          "preconditions for fuse operation failed");
682   // Check if reshape is expanding or collapsing.
683   auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp);
684   auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp);
685   bool isExpanding = (expandingReshapeOp != nullptr);
686   RankedTensorType expandedType = isExpanding
687                                       ? expandingReshapeOp.getResultType()
688                                       : collapsingReshapeOp.getSrcType();
689 
690   ExpansionInfo expansionInfo;
691   if (failed(expansionInfo.compute(
692           genericOp, fusableOpOperand,
693           isExpanding ? expandingReshapeOp.getReassociationMaps()
694                       : collapsingReshapeOp.getReassociationMaps(),
695           expandedType.getShape(), rewriter)))
696     return llvm::None;
697 
698   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
699     return llvm::None;
700 
701   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
702       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
703         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
704       }));
705 
706   SmallVector<Value> expandedOpOperands;
707   expandedOpOperands.reserve(genericOp.getNumInputs());
708   for (OpOperand *opOperand : genericOp.getInputOperands()) {
709     if (opOperand == fusableOpOperand) {
710       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
711                                                : collapsingReshapeOp.src());
712       continue;
713     }
714     if (genericOp.isInputTensor(opOperand)) {
715       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
716       RankedTensorType expandedOperandType =
717           getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
718                           indexingMap, expansionInfo);
719       if (expandedOperandType != opOperand->get().getType()) {
720         // Reshape the operand to get the right type.
721         SmallVector<ReassociationIndices> reassociation =
722             getReassociationForExpansion(indexingMap, expansionInfo);
723         expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
724             genericOp.getLoc(), expandedOperandType, opOperand->get(),
725             reassociation));
726         continue;
727       }
728     }
729     expandedOpOperands.push_back(opOperand->get());
730   }
731 
732   Location loc = genericOp.getLoc();
733   SmallVector<Value> outputs;
734   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
735     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
736     RankedTensorType expandedOutputType =
737         getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
738                         indexingMap, expansionInfo);
739     if (expandedOutputType != opOperand->get().getType()) {
740       SmallVector<ReassociationIndices> reassociation =
741           getReassociationForExpansion(indexingMap, expansionInfo);
742       outputs.push_back(rewriter.create<TensorExpandShapeOp>(
743           genericOp.getLoc(), expandedOutputType, opOperand->get(),
744           reassociation));
745     }
746   }
747 
748   // The iterator types of the expanded op are all parallel.
749   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
750                                        getParallelIteratorTypeName());
751 
752   TypeRange resultTypes = ValueRange(outputs).getTypes();
753   auto fusedOp =
754       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
755                                  /*inputs=*/expandedOpOperands, outputs,
756                                  expandedOpIndexingMaps, iteratorTypes);
757   Region &fusedRegion = fusedOp->getRegion(0);
758   Region &originalRegion = genericOp->getRegion(0);
759   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
760 
761   // Update the index accesses after the expansion.
762   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
763 
764   // Reshape the result values to their original shape if this is a collapsing
765   // reshape folded into its consumer.
766   SmallVector<Value> resultVals;
767   for (OpResult opResult : genericOp->getOpResults()) {
768     int64_t resultNumber = opResult.getResultNumber();
769     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
770       SmallVector<ReassociationIndices> reassociation =
771           getReassociationForExpansion(
772               genericOp.getTiedIndexingMap(
773                   genericOp.getOutputOperand(resultNumber)),
774               expansionInfo);
775       resultVals.push_back(rewriter.create<TensorCollapseShapeOp>(
776           genericOp.getLoc(), opResult.getType(),
777           fusedOp->getResult(resultNumber), reassociation));
778     } else {
779       resultVals.push_back(fusedOp->getResult(resultNumber));
780     }
781   }
782   // Assuming a single result.
783   return resultVals;
784 }
785 
786 namespace {
787 
788 /// Pattern to fold tensor_expand_shape op with its consumer by using the source
789 /// of the reshape op as the operand in the consumer (instead of the result of
790 /// the tensor_collapse_shape). The corresponding index map in the consumer
791 /// needs to be modified to linearize the folded dimension.
792 ///
793 /// For example,
794 ///
795 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
796 /// %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]]
797 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
798 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
799 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
800 ///        -> tensor<?x?x4x?xf32>
801 ///
802 /// can be folded into
803 ///
804 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
805 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
806 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
807 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
808 ///        -> tensor<?x?x4x?xf32>
809 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
810 struct FoldProducerReshapeOpByLinearization
811     : public OpRewritePattern<GenericOp> {
812   using OpRewritePattern<GenericOp>::OpRewritePattern;
813 
814   LogicalResult matchAndRewrite(GenericOp genericOp,
815                                 PatternRewriter &rewriter) const override {
816     if (!genericOp.hasTensorSemantics())
817       return failure();
818     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
819     for (auto en : llvm::enumerate(inputOperands)) {
820       auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
821       if (!reshapeOp)
822         continue;
823 
824       if (!isTensorReshapeOpFoldableByLinearization(
825               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
826               /*asProducer =*/true) ||
827           (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
828         continue;
829 
830       // Compute the fused operands list,
831       SmallVector<Value> fusedOperands = genericOp.getInputOperands();
832       fusedOperands[en.index()] = reshapeOp.src();
833       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
834       llvm::append_range(fusedOperands, outputOperands);
835 
836       // Compute indexing_maps for the fused operation. The indexing_maps for
837       // the operands of the consumers that arent fused are the same.
838       SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
839 
840       // Accepted consumer maps are either identity or permutation.
841       auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
842 
843       // Compute the indexing map to use for the result of the producer.
844       AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
845       // The modified map cannot have symbols.
846       if (modifiedMap.getNumSymbols())
847         return failure();
848       for (AffineExpr expr : modifiedMap.getResults()) {
849         if (!expr.isPureAffine())
850           return failure();
851       }
852       fusedIndexMaps[en.index()] = modifiedMap;
853 
854       // Further check that the resulting index maps can be fused and
855       // inverted. Without this the resultant op is not legal.
856       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
857         return rewriter.notifyMatchFailure(
858             genericOp, "fused op loop bound computation failed");
859       }
860 
861       rewriter.startRootUpdate(genericOp);
862       genericOp->setOperands(fusedOperands);
863       genericOp.indexing_mapsAttr(
864           rewriter.getAffineMapArrayAttr(fusedIndexMaps));
865       rewriter.finalizeRootUpdate(genericOp);
866       return success();
867     }
868     return failure();
869   }
870 };
871 
872 static SmallVector<ReassociationIndices>
873 getReassociationIndices(ArrayRef<AffineMap> maps) {
874   SmallVector<ReassociationIndices> reassociation;
875   for (AffineMap map : maps) {
876     ReassociationIndices indices;
877     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
878       unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
879       indices.push_back(pos);
880     }
881     reassociation.push_back(indices);
882   }
883   return reassociation;
884 }
885 
886 /// Pattern to move rank reducing reshape after an elementwise linalg generic
887 /// op. This is useful to expose more fusion opportunities between named ops and
888 /// generic ops. This can only be done if there is no broadcast or permuation
889 /// within the dimensions we need to merge.
890 ///
891 /// For example,
892 ///
893 ///  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
894 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
895 ///  %2 = linalg.generic {indexing_maps = [
896 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
897 ///    affine_map<(d0, d1, d2) -> (d2)>,
898 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
899 ///    ["parallel", "parallel", "parallel"]} {
900 ///  } -> tensor<112x112x16xf32>
901 ///
902 ///  into
903 ///
904 ///  %2 = linalg.generic {indexing_maps = [
905 ///    affine_map<(d0, d1) -> (d0, d1)>,
906 ///    affine_map<(d0, d1) -> (d1)>,
907 ///    affine_map<(d0, d1) -> (d0, d1)>],
908 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
909 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
910 ///  } -> tensor<12544x16xf32>
911 ///  %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]]
912 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
913 struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
914   using OpRewritePattern<GenericOp>::OpRewritePattern;
915 
916   LogicalResult matchAndRewrite(GenericOp genericOp,
917                                 PatternRewriter &rewriter) const override {
918     // Only apply to elementwise linalg on tensor.
919     if (!genericOp.hasTensorSemantics() ||
920         genericOp.getNumParallelLoops() != genericOp.getNumLoops())
921       return failure();
922     // Only support identity output maps. It could be extended to permuations if
923     // needed.
924     if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
925           return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
926         }))
927       return failure();
928     int64_t destRank = genericOp.getNumParallelLoops();
929     SmallVector<Value> newOperands = genericOp.getInputOperands();
930     TensorExpandShapeOp reshapeFound;
931     // 1. Look for tensor_expand_shape operands and figure out save the
932     // dimensions merged.
933     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
934     for (auto en : llvm::enumerate(inputOperands)) {
935       auto reshapeOp =
936           en.value()->get().template getDefiningOp<TensorExpandShapeOp>();
937       if (!reshapeOp)
938         continue;
939       // TODO: We could support non-identity map as long as the merged
940       // dimensions are still contiguous.
941       if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
942         continue;
943       if (reshapeFound) {
944         // Only support a second reshape op if it has the same reassociate maps.
945         if (reshapeFound.getReassociationMaps() ==
946             reshapeOp.getReassociationMaps())
947           newOperands[en.index()] = reshapeOp.src();
948         continue;
949       }
950       reshapeFound = reshapeOp;
951       newOperands[en.index()] = reshapeOp.src();
952     }
953     if (!reshapeFound)
954       return failure();
955 
956     // Calculate the reassociation indices and rassociated reverse map.
957     SmallVector<ReassociationIndices> reassociation =
958         getReassociationIndices(reshapeFound.getReassociationMaps());
959     SmallVector<unsigned> remap(destRank);
960     for (auto &indices : llvm::enumerate(reassociation)) {
961       for (int64_t index : indices.value()) {
962         remap[index] = indices.index();
963       }
964     }
965     // 2. Verify that we can merge the dimensions in the linalg and that we
966     // don't need to create new reshapes operands. Inserting new reshape
967     // operands would defeat the purpose of the transformation.
968     for (auto en : llvm::enumerate(inputOperands)) {
969       if (en.value()->get() == newOperands[en.index()]) {
970         AffineMap map = genericOp.getTiedIndexingMap(en.value());
971         for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
972           if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
973             return failure();
974         }
975       }
976     }
977 
978     // 3. Calculate the affine map remapping and the reassociation to apply to
979     // output tensors.
980     SmallVector<AffineMap> newMaps;
981     unsigned newRank = reassociation.size();
982     for (auto map : genericOp.getIndexingMaps()) {
983       SmallVector<AffineExpr> newExprs;
984       for (auto expr : map.getResults()) {
985         unsigned position = expr.template cast<AffineDimExpr>().getPosition();
986         // Skip dimension merged except for the last of the group.
987         if (reassociation[remap[position]].back() == position) {
988           newExprs.push_back(
989               getAffineDimExpr(remap[position], genericOp.getContext()));
990         }
991       }
992       newMaps.push_back(
993           AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
994     }
995 
996     // 4. Reshape the output tensors.
997     SmallVector<Value> newOutputs;
998     SmallVector<Type> newOutputTypes;
999     for (auto output : genericOp.outputs()) {
1000       auto newOutputType = RankedTensorType::get(
1001           reshapeFound.getSrcType().getShape(),
1002           output.getType().template cast<RankedTensorType>().getElementType());
1003       Value newOutput = rewriter.create<TensorCollapseShapeOp>(
1004           genericOp->getLoc(), newOutputType, output, reassociation);
1005       newOutputTypes.push_back(newOutputType);
1006       newOutputs.push_back(newOutput);
1007     }
1008     // 5. Create a new generic op with lowerer rank.
1009     SmallVector<StringRef> iteratorTypes(newRank,
1010                                          getParallelIteratorTypeName());
1011     auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
1012                                             newOperands, newOutputs, newMaps,
1013                                             iteratorTypes);
1014     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
1015                                 newOp.region().begin());
1016     // 6. Reshape the so that the type matches the uses.
1017     SmallVector<Value> newResults;
1018     for (auto result : llvm::enumerate(newOp->getResults())) {
1019       newResults.push_back(rewriter.create<TensorExpandShapeOp>(
1020           genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
1021           result.value(), reassociation));
1022     }
1023     rewriter.replaceOp(genericOp, newResults);
1024     return success();
1025   }
1026 };
1027 
1028 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
1029 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
1030 /// in the consumer is expanded.
1031 class FoldWithProducerReshapeOpByExpansion
1032     : public OpRewritePattern<GenericOp> {
1033 public:
1034   FoldWithProducerReshapeOpByExpansion(
1035       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1036       PatternBenefit benefit = 1)
1037       : OpRewritePattern<GenericOp>(context, benefit),
1038         controlFoldingReshapes(foldReshapes) {}
1039 
1040   LogicalResult matchAndRewrite(GenericOp genericOp,
1041                                 PatternRewriter &rewriter) const override {
1042     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1043       TensorCollapseShapeOp reshapeOp =
1044           opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
1045       if (!reshapeOp)
1046         continue;
1047       // Fold only if
1048       // - The tensor reshape op is folding.
1049       // - All constraints of fusing with reshape by expansion are met.
1050       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1051           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1052         continue;
1053 
1054       Optional<SmallVector<Value>> replacementValues =
1055           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1056       if (!replacementValues)
1057         return failure();
1058       rewriter.replaceOp(genericOp, replacementValues.getValue());
1059       return success();
1060     }
1061     return failure();
1062   }
1063 
1064 private:
1065   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1066 };
1067 
1068 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
1069 /// producer. The corresponding index map in the consumer needs to be modified
1070 /// to linearize the folded dimension.
1071 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
1072 struct FoldConsumerReshapeOpByLinearization
1073     : public OpRewritePattern<TensorReshapeOp> {
1074   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1075 
1076   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1077                                 PatternRewriter &rewriter) const override {
1078     GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
1079     if (!producer || !producer.hasTensorSemantics() ||
1080         producer.getNumOutputs() != 1 ||
1081         !isTensorReshapeOpFoldableByLinearization(
1082             reshapeOp,
1083             producer.getTiedIndexingMap(producer.getOutputOperand(0)),
1084             /*asProducer =*/false) ||
1085         (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
1086       return failure();
1087     // The indexing_maps for the operands of the fused operation are same as
1088     // those for the operands of the producer.
1089     SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
1090 
1091     auto invMap = inversePermutation(
1092         producer.getTiedIndexingMap(producer.getOutputOperand(0)));
1093 
1094     // Compute the indexing map to use for the operand of the producer.
1095     AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
1096     for (AffineExpr expr : modifiedMap.getResults()) {
1097       if (!expr.isPureAffine()) {
1098         return rewriter.notifyMatchFailure(
1099             producer, "fused op indexing map is not affine");
1100       }
1101     }
1102     fusedIndexMaps.back() = modifiedMap;
1103 
1104     // Further check that the resulting index maps can be fused and
1105     // inverted. Without this the resultant op is not legal.
1106     if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1107       return rewriter.notifyMatchFailure(
1108           producer, "fused op loop bound computation failed");
1109     }
1110 
1111     Location loc = producer.getLoc();
1112     SmallVector<Value> inputOperands = producer.getInputOperands();
1113     Value output = rewriter.create<TensorReshapeOp>(
1114         loc, producer.getOutputOperand(0)->get(),
1115         reshapeOp.getReassociationExprs());
1116     auto fusedOp = rewriter.create<GenericOp>(
1117         loc, reshapeOp.getResultType(),
1118         /*inputs=*/inputOperands,
1119         // TODO: handle outputs.
1120         /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1121         producer.iterator_types(),
1122         /*doc=*/nullptr,
1123         /*library_call=*/nullptr);
1124     auto &fusedRegion = fusedOp->getRegion(0);
1125     rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
1126                                fusedRegion.begin());
1127     rewriter.replaceOp(reshapeOp, fusedOp->getResults());
1128     return success();
1129   }
1130 };
1131 
1132 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1133 /// by expanding the dimensionality of the loop in the producer op.
1134 struct FoldReshapeWithGenericOpByExpansion
1135     : public OpRewritePattern<TensorExpandShapeOp> {
1136   using OpRewritePattern<TensorExpandShapeOp>::OpRewritePattern;
1137   LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
1138                                 PatternRewriter &rewriter) const override {
1139     // Fold only if all constraints of fusing with reshape by expansion are met.
1140     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1141     if (!producer || producer.getNumOutputs() != 1 ||
1142         !isFusableWithReshapeByDimExpansion(producer,
1143                                             producer.getOutputOperand(0)) ||
1144         isUnitDimExpansionOnly(reshapeOp))
1145       return failure();
1146     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1147         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1148     if (!replacementValues)
1149       return failure();
1150     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1151     return success();
1152   }
1153 };
1154 
1155 /// Pattern to fold a generic op with a splat constant.
1156 class FoldSplatConstants : public OpRewritePattern<GenericOp> {
1157 public:
1158   FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
1159                      PatternBenefit benefit = 1)
1160       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1161 
1162   LogicalResult matchAndRewrite(GenericOp genericOp,
1163                                 PatternRewriter &rewriter) const override {
1164     if (!genericOp.hasTensorSemantics())
1165       return failure();
1166     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1167       Operation *def = opOperand->get().getDefiningOp();
1168       DenseElementsAttr constantAttr;
1169       if (!def ||
1170           !matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
1171           !constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
1172         continue;
1173 
1174       // The operands and the indexing_maps of the fused operation the same as
1175       // the operands and indexing_maps of the generic operations with the
1176       // values at the constant index dropped.
1177       SmallVector<AffineMap> fusedIndexMaps;
1178       SmallVector<Value> fusedOperands;
1179       fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1180       fusedOperands.reserve(genericOp.getNumInputs());
1181       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1182         if (inputOperand == opOperand)
1183           continue;
1184         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1185         fusedOperands.push_back(inputOperand->get());
1186       }
1187       for (OpOperand *outputOperand : genericOp.getOutputOperands())
1188         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1189 
1190       // Check if the operation shapes to loops map is computable.
1191       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1192         return rewriter.notifyMatchFailure(
1193             genericOp, "fused op loop bound computation failed");
1194       }
1195 
1196       // Create a constant scalar value from the splat constant.
1197       Value scalarConstant = rewriter.create<ConstantOp>(
1198           def->getLoc(), constantAttr.getSplatValue(),
1199           constantAttr.getType().getElementType());
1200 
1201       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1202       auto fusedOp = rewriter.create<GenericOp>(
1203           rewriter.getUnknownLoc(), genericOp->getResultTypes(),
1204           /*inputs=*/fusedOperands,
1205           /*outputs=*/outputOperands,
1206           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1207           genericOp.iterator_types(),
1208           /*doc=*/nullptr,
1209           /*library_call=*/nullptr);
1210 
1211       // Map the block argument corresponding to the replaced argument with the
1212       // scalar constant.
1213       Region &region = genericOp->getRegion(0);
1214       Block &entryBlock = *region.begin();
1215       BlockAndValueMapping mapping;
1216       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1217                   scalarConstant);
1218       Region &fusedRegion = fusedOp->getRegion(0);
1219       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1220                                  mapping);
1221       rewriter.replaceOp(genericOp, fusedOp->getResults());
1222       return success();
1223     }
1224     return failure();
1225   }
1226 
1227 private:
1228   ControlElementwiseOpsFusionFn controlFn;
1229 };
1230 } // namespace
1231 
1232 static Optional<SmallVector<Value>>
1233 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
1234                    GenericOp producer,
1235                    const ControlElementwiseOpsFusionFn &controlFn) {
1236   if (producer->getNumResults() != 1)
1237     return llvm::None;
1238 
1239   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
1240                                 rewriter);
1241 }
1242 
1243 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
1244                                       OpOperand &consumer) {
1245   auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
1246   if (expandShapeOp)
1247     return !isUnitDimExpansionOnly(expandShapeOp);
1248   auto collapseShapeOp =
1249       producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
1250   return !isUnitDimExpansionOnly(collapseShapeOp);
1251 }
1252 
1253 namespace {
1254 /// Patterns to fuse a generic op, with the producer of its operands.
1255 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
1256 public:
1257   FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
1258                      PatternBenefit benefit = 1)
1259       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1260 
1261   LogicalResult matchAndRewrite(GenericOp genericOp,
1262                                 PatternRewriter &rewriter) const override {
1263     // Find the first operand that is defined by another generic op on tensors.
1264     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
1265       auto producer =
1266           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
1267       if (!producer || !producer.hasTensorSemantics())
1268         continue;
1269       Optional<SmallVector<Value>> fusedOpResults =
1270           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
1271       if (fusedOpResults) {
1272         rewriter.replaceOp(genericOp, *fusedOpResults);
1273         return success();
1274       }
1275     }
1276     return failure();
1277   }
1278 
1279 private:
1280   ControlElementwiseOpsFusionFn controlFn;
1281 };
1282 
1283 /// Pass that fuses generic ops on tensors. Used only for testing.
1284 struct LinalgElementwiseOpFusionPass
1285     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1286   void runOnOperation() override {
1287     Operation *op = getOperation();
1288     RewritePatternSet patterns(op->getContext());
1289     ControlElementwiseOpsFusionFn allowFoldingFn =
1290         [](const OpResult &producer, const OpOperand &consumer) {
1291           return true;
1292         };
1293     populateElementwiseOpsFusionPatterns(
1294         patterns,
1295         LinalgElementwiseFusionOptions().setControlFoldingReshapes(
1296             allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
1297 
1298     // Use TopDownTraversal for compile time reasons
1299     GreedyRewriteConfig grc;
1300     grc.useTopDownTraversal = true;
1301     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1302                                        grc);
1303   }
1304 };
1305 
1306 /// Pass to test folding of reshape ops with generic ops by linearization.
1307 struct FoldReshapeOpsByLinearizationPass
1308     : public LinalgFoldReshapeOpsByLinearizationBase<
1309           FoldReshapeOpsByLinearizationPass> {
1310   void runOnOperation() override {
1311     Operation *op = getOperation();
1312     RewritePatternSet patterns(op->getContext());
1313     populateFoldReshapeOpsByLinearizationPatterns(patterns);
1314     if (allowFoldingUnitDimReshapes) {
1315       populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
1316     }
1317     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1318   }
1319 };
1320 
1321 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
1322 /// the value of the `outs` operand is not used within the op.  This is only
1323 /// implemented for `linalg.generic` operations for now, but should hold for all
1324 /// linalg structured ops.
1325 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1326   using OpRewritePattern<GenericOp>::OpRewritePattern;
1327 
1328   LogicalResult matchAndRewrite(GenericOp op,
1329                                 PatternRewriter &rewriter) const override {
1330     rewriter.startRootUpdate(op);
1331     bool modifiedOutput = false;
1332     Location loc = op.getLoc();
1333     for (OpOperand *opOperand : op.getOutputOperands()) {
1334       if (!op.payloadUsesValueFromOperand(opOperand)) {
1335         Value operandVal = opOperand->get();
1336         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1337         if (!operandType)
1338           continue;
1339 
1340         // If outs is already an `init_tensor` operation, nothing to do.
1341         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1342         if (definingOp)
1343           continue;
1344         modifiedOutput = true;
1345         SmallVector<Value> dynamicDims;
1346         for (auto dim : llvm::enumerate(operandType.getShape())) {
1347           if (dim.value() != ShapedType::kDynamicSize)
1348             continue;
1349           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1350               loc, operandVal, dim.index()));
1351         }
1352         Value initTensor = rewriter.create<InitTensorOp>(
1353             loc, dynamicDims, operandType.getShape(),
1354             operandType.getElementType());
1355         op->setOperand(opOperand->getOperandNumber(), initTensor);
1356       }
1357     }
1358     if (!modifiedOutput) {
1359       rewriter.cancelRootUpdate(op);
1360       return failure();
1361     }
1362     rewriter.finalizeRootUpdate(op);
1363     return success();
1364   }
1365 };
1366 
1367 } // namespace
1368 
1369 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
1370     RewritePatternSet &patterns) {
1371   patterns
1372       .add<FoldProducerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
1373            FoldProducerReshapeOpByLinearization<false, TensorExpandShapeOp>,
1374            FoldConsumerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
1375            FoldConsumerReshapeOpByLinearization<false, TensorExpandShapeOp>>(
1376           patterns.getContext());
1377 }
1378 
1379 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
1380     RewritePatternSet &patterns) {
1381   patterns
1382       .add<FoldProducerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
1383            FoldProducerReshapeOpByLinearization<true, TensorExpandShapeOp>,
1384            FoldConsumerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
1385            FoldConsumerReshapeOpByLinearization<true, TensorExpandShapeOp>>(
1386           patterns.getContext());
1387 }
1388 
1389 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1390     RewritePatternSet &patterns,
1391     ControlElementwiseOpsFusionFn controlFoldingReshapes) {
1392   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
1393   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1394                                                      controlFoldingReshapes);
1395 }
1396 
1397 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1398     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
1399   auto *context = patterns.getContext();
1400   patterns.add<FuseElementwiseOps, FoldSplatConstants>(
1401       context, options.controlElementwiseOpsFusionFn);
1402   patterns.add<RemoveOutsDependency>(context);
1403   populateFoldReshapeOpsByExpansionPatterns(patterns,
1404                                             options.controlFoldingReshapesFn);
1405   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1406   GenericOp::getCanonicalizationPatterns(patterns, context);
1407   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1408   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1409   context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1410       patterns);
1411 }
1412 
1413 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
1414   auto *context = patterns.getContext();
1415   patterns.add<PushExpandingReshape>(context);
1416 }
1417 
1418 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1419   return std::make_unique<LinalgElementwiseOpFusionPass>();
1420 }
1421 
1422 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1423   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1424 }
1425