//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

/// Return the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the result of the producer in
/// the consumer.
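/// As a worked example (illustrative, not taken from the tests): if the
/// producer operand map is `affine_map<(d0, d1) -> (d1, d0)>`, the producer
/// result map is the identity, and the consumer reads the producer result
/// through `affine_map<(d0, d1) -> (d1, d0)>`, the two transpositions cancel
/// and the returned map is the identity map.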
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index;
  // compute its inverse to get a map from producer result tensor index ->
  // producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Conditions for elementwise fusion of generic operations.
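/// For example (an illustrative case), with the checks below satisfied, the
/// producer %0 can be fused into its consumer %1:
///
///   #map = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = linalg.generic {indexing_maps = [#map, #map],
///          iterator_types = ["parallel", "parallel"]}
///          ins(%a : tensor<?x?xf32>) outs(%init0 : tensor<?x?xf32>) {...}
///   %1 = linalg.generic {indexing_maps = [#map, #map, #map],
///          iterator_types = ["parallel", "parallel"]}
///          ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
///          outs(%init1 : tensor<?x?xf32>) {...}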
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  // Verify that the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isInputTensor(consumerOpOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Currently only support operations with a single result.
  if (producer.getNumOutputs() != 1)
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    llvm::BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
      Value operand = std::get<0>(pair);
      if (operand == consumerOpOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
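/// The fused block arguments are added in the order established by the steps
/// below: the consumer inputs before the fused operand, the producer inputs,
/// the producer init operand (only when its value is actually read), the
/// remaining consumer inputs, and finally the consumer outputs.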
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity check: if the replacement is not already in the mapper, then it
  // must be produced outside the producer block.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

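/// Fuse `producer` into the consumer (the owner of `consumerOpOperand`) if
/// the two ops are fusable and the fusion is accepted by `controlFn`. Returns
/// the values to use as replacements for the consumer's results, or
/// llvm::None on failure.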
static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlElementwiseOpsFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
      !controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(producer->getNumOperands() +
                        consumer->getNumOperands());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator it =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  for (OpOperand *opOperand : producer.getInputOperands()) {
    fusedOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] :
///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
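/// (The strides follow from the row-major sizes of the collapsed group:
/// 4 * 5 = 20 for d1, 5 for d2, and 1 for d3.)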
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  // The new affine map cannot drop unused dimensions, but some new symbols may
  // have been added. Create a map with at least as many dimensions/symbols as
  // the original affine map.
  int64_t maxDim = -1;
  int64_t maxSym = -1;
  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
  return AffineMap::get(numDims, numSyms, resultExprs, context);
}

// TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a
// producer). Fusing when the operand has a higher rank would require mods and
// divs in the indexing maps of the fused op, which would make it
// non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
    TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer)
    return false;
  return useIndexMap.isPermutation();
}

// TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a
// consumer).
static bool isTensorReshapeOpFoldableByLinearization(
    TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
  if (asProducer)
    return false;
  return useIndexMap.isPermutation();
}

/// Check if the reshape operation only expands into or collapses unit
/// dimensions.
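/// For example (illustrative), expanding tensor<4xf32> into tensor<1x4xf32>
/// with reassociation [[0, 1]] only introduces a unit dimension, whereas
/// expanding tensor<8xf32> into tensor<2x4xf32> does not qualify.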
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t position : indices)
      if (expandedShape[position] == 1)
        numUnitDims++;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
/// following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `genericOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor_expand_shape.
///  The indexing_map of the fused tensor in the `genericOp` and the
///  reassociation map helps compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the generic op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the modified op is now 6D, reshapes can be
///  introduced on the operands to make them consistent
///
///  %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
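  // For example (illustrative): with a fused operand indexing map
  // (d0, d1) -> (d1, d0), reassociation [[0, 1], [2]] and expanded shape
  // [4, 8, ?], d1 is expanded into two dimensions of shape [4, 8] while d0
  // remains a single dimension of shape [?], so the expanded op has three
  // loops.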
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of loops in the original operation to the extent
  /// of loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, indices accessed in the original
/// operation need to be replaced with linearizations of the indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply`, which
/// seems to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
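/// For example (illustrative), if d0 expands into (e0, e1) and d1 into (e2),
/// the original map `affine_map<(d0, d1) -> (d1, d0)>` becomes
/// `affine_map<(e0, e1, e2) -> (e2, e0, e1)>`.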
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `linalg.tensor_expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `linalg.tensor_collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
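/// For example (illustrative), with the expansion d0 -> (e0, e1), d1 -> (e2)
/// and an operand indexing map `affine_map<(d0, d1) -> (d1, d0)>`, the
/// operand's reassociation is [[0], [1, 2]]: the expanded dimensions are
/// numbered in the order the operand accesses them.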
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
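/// For example (illustrative), if loop d0 of the original op expands into
/// (e0, e1) with shapes (?, 4), the original `linalg.index 0` is replaced by
/// an `affine.apply` computing idx(e1) + idx(e0) * 4 over new `linalg.index`
/// operations for e0 and e1.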
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}


/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), rewriter)))
    return llvm::None;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumInputs());
  for (OpOperand *opOperand : genericOp.getInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                               : collapsingReshapeOp.src());
      continue;
    }
    if (genericOp.isInputTensor(opOperand)) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                          indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
    RankedTensorType expandedOutputType =
        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                        indexingMap, expansionInfo);
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      outputs.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    } else {
      // The output already has the expanded type; forward it unchanged so the
      // number of outputs matches the number of indexing maps.
      outputs.push_back(opOperand->get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<TensorCollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fold a tensor_expand_shape op with its consumer by using the
/// source of the reshape op as the operand in the consumer (instead of the
/// result of the reshape op). The corresponding index map in the consumer
/// needs to be modified to linearize the folded dimensions.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]]
///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

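/// Return the dimension positions referenced by each of the given
/// reassociation maps, e.g. affine_map<(d0, d1, d2) -> (d0, d1)> yields
/// [0, 1].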
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
///
/// For example,
///
///  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///    ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
///  into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]]
///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg ops on tensors.
    if (!genericOp.hasTensorSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    TensorExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and record the dimensions that
    // get merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<TensorExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support a non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the reverse dimension mapping.
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg op and that we
    // don't need to create new reshape operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (auto en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip merged dimensions except for the last one of each group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<TensorCollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that their types match the uses.
    SmallVector<Value> newResults;
    for (auto result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
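/// For example (illustrative):
///
///   %0 = linalg.tensor_collapse_shape %a [[0, 1], [2]]
///        : tensor<?x?x?xf32> into tensor<?x?xf32>
///   %1 = linalg.generic ... ins(%0 : tensor<?x?xf32>) ...
///
/// becomes a generic op operating on %a directly, with its iteration space
/// expanded to the pre-collapse rank.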
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      TensorCollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - all constraints of fusing with reshape by expansion are met, and
      // - the fusion is accepted by the reshape folding control function.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding output index map of the producer needs to be
/// modified to linearize the folded dimensions.
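/// For example (illustrative):
///
///   %0 = linalg.generic ... -> tensor<?x?x4x?xf32>
///   %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2], [3]]
///        : tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
///
/// is folded by linearizing the output indexing map of the generic op so that
/// it produces the collapsed tensor directly.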
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer =*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    auto invMap = inversePermutation(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};

1173 
1174 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1175 /// by expanding the dimensionality of the loop in the producer op.
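///
/// A rough example (hypothetical IR, not taken from a test):
///
///   %0 = linalg.generic ... -> tensor<12xf32>
///   %1 = linalg.tensor_expand_shape %0 [[0, 1]]
///       : tensor<12xf32> into tensor<3x4xf32>
///
/// becomes a single linalg.generic iterating over the expanded 3x4 loop nest
/// and producing tensor<3x4xf32> directly.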
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<TensorExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<TensorExpandShapeOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
    if (!producer || producer.getNumOutputs() != 1 ||
        !isFusableWithReshapeByDimExpansion(producer,
                                            producer.getOutputOperand(0)) ||
        !controlFoldingReshapes(producer->getResult(0),
                                reshapeOp->getOpOperand(0)))
      return failure();
    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
    if (!replacementValues)
      return failure();
    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a generic op with a splat constant or a scalar constant.
/// Cases where the constant is not single-valued are not handled.
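///
/// Sketch (hypothetical IR): an input operand fed by
///
///   %cst = arith.constant dense<2.0> : tensor<4xf32>
///
/// is dropped from the generic op's operand list, and the corresponding block
/// argument is replaced inside the region by the scalar
///
///   %c = arith.constant 2.0 : f32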
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the
      // same as the operands and indexing_maps of the generic operation, with
      // the values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation's shapes-to-loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced operand to the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the
/// output.
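///
/// A minimal sketch of a concrete subclass (`FoldConstantNegate` is
/// hypothetical; see FoldConstantTranspose below for a real instance):
///
/// ```c++
///   struct FoldConstantNegate : FoldConstantBase<FoldConstantNegate> {
///     using FoldConstantBase::FoldConstantBase;
///     // Expect one input and one output.
///     bool matchIndexingMaps(GenericOp op) const {
///       return op.getIndexingMaps().size() == 2;
///     }
///     RegionComputationFn getRegionComputeFn(GenericOp op) const {
///       // A real implementation would first verify that the region is a
///       // single arith.negf followed by linalg.yield.
///       return [](const APIntOrFloatArray &in) {
///         APFloat v = in.apFloats.front();
///         v.changeSign();
///         return APIntOrFloat{llvm::None, v};
///       };
///     }
///   };
/// ```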
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output type to be static given that we are generating
    // constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, apply affine maps to them, and then match the result back to a
    // constant.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check that the indexing maps are acceptable to the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // This op is now a potential candidate for folding. Check the control
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases, but they live as long as the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
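    // For example, affine_map<(d0, d1) -> (d1, d0)> yields [1, 0].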
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate space for the compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, and `dstLinearIndex` in place.
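    // Illustrative example: with loopBounds = [2, 3], linearIndex = 1
    // delinearizes to indices = [0, 1]; an output permutation map
    // (d0, d1) -> (d1, d0) over the resulting 3x2 output then gives
    // dstIndices = [1, 0] and dstLinearIndex = 1 * 2 + 0 = 2.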
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Folds linalg.generic ops that are actually transposes on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument that corresponds to the
    // input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), llvm::None};
      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
    };
  }
};

} // namespace

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

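/// Control function that rejects folding a tensor_collapse_shape or
/// tensor_expand_shape op when the reshape only expands or collapses unit
/// dimensions, and allows the fusion otherwise.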
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}

namespace {
/// Pattern to fuse a generic op with the producers of its operands.
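///
/// For example (an illustrative sketch of the IR, not from a test):
///
///   %0 = linalg.generic ... { addf } ins(%a, %b : ...) -> tensor<?xf32>
///   %1 = linalg.generic ... { mulf } ins(%0, %c : ...) -> tensor<?xf32>
///
/// is rewritten into a single linalg.generic whose region performs the addf
/// followed by the mulf, eliminating the intermediate tensor %0.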
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    ControlElementwiseOpsFusionFn allowFoldingFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          return true;
        };
    populateElementwiseOpsFusionPatterns(
        patterns,
        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));

    // Use top-down traversal for compile-time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
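///
/// Sketch (hypothetical IR): when the payload does not read %arg,
///
///   %0 = linalg.generic ... outs(%arg : tensor<?xf32>) ...
///
/// is rewritten to
///
///   %d0 = tensor.dim %arg, %c0 : tensor<?xf32>
///   %init = linalg.init_tensor [%d0] : tensor<?xf32>
///   %0 = linalg.generic ... outs(%init : tensor<?xf32>) ...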
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (auto dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

} // namespace

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<false, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
               FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
      patterns);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}