1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
31 /// the `producer` to use in the fused operation given the indexing map of the
32 /// result of the producer in the consumer.
33 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
34     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
35     AffineMap fusedConsumerArgIndexMap) {
36   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
37   // from consumer loop -> consumer arg tensor index/producer result tensor
38   // index. The fused loop is same as the consumer loop. For each producer arg
39   // the indexing map to be computed is a map from consumer loop -> producer
40   // arg tensor index.
41   // producerResultIndexMap is a map from producer loop -> tensor index.
42   // Compute the inverse to get map from tensor index -> producer loop.
43   // The inverse is a map from producer result tensor index -> producer loop.
44   AffineMap invProducerResultIndexMap =
45       inversePermutation(producerResultIndexMap);
46   assert(invProducerResultIndexMap &&
47          "expected producer result indexig map to be invertible");
48 
49   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
50   // argMap is a map from producer loop -> producer arg tensor index.
51   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
52 
53   // Compose argMap with invProducerResultIndexMap to get a map from
54   // producer result tensor index -> producer arg tensor index.
55   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
56 
57   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
58   // consumer loop/ fused loop -> producer arg tensor index.
59   return t1.compose(fusedConsumerArgIndexMap);
60 }
61 
62 /// Conditions for elementwise fusion of generic operations.
63 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
64                                      OpOperand *consumerOpOperand) {
65   // Producer and consumer must have tensor semantics.
66   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
67     return false;
68 
69   // Verify that
70   // - the producer has all "parallel" iterator type.
71   if (producer.getNumParallelLoops() != producer.getNumLoops())
72     return false;
73 
74   // Only allow fusing the producer of an input operand for now.
75   // TODO: allow fusing the producer of an output operand.
76   if (!consumer.isInputTensor(consumerOpOperand))
77     return false;
78 
79   // Get the consumer index map. The number of results of the consumer index
80   // map must match the number of loops of the producer.
81   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
82   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
83     return false;
84 
85   // Currently support only operations with single result.
86   if (producer.getNumOutputs() != 1)
87     return false;
88 
89   // Finally the index_map for the result must be invertible. For now just
90   // verify it is a permutation.
91   AffineMap producerResultIndexMap =
92       producer.getTiedIndexingMap(producer.getOutputOperand(0));
93   if (!producerResultIndexMap.isPermutation())
94     return false;
95 
96   // Ensure that the fusion does not remove size information required to
97   // get the loop bounds. For non-reduction generics, this is trivially the
98   // case due to the output operand. For reductions, we need to check that after
99   // the fusion, each loop dimension has at least one input that defines it.
100   if ((consumer.getNumReductionLoops())) {
101     llvm::BitVector coveredDims(consumer.getNumLoops(), false);
102 
103     auto addToCoveredDims = [&](AffineMap map) {
104       for (auto result : map.getResults())
105         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
106           coveredDims[dimExpr.getPosition()] = true;
107     };
108 
109     for (auto pair :
110          llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
111       Value operand = std::get<0>(pair);
112       if (operand == consumerOpOperand->get())
113         continue;
114       AffineMap operandMap = std::get<1>(pair);
115       addToCoveredDims(operandMap);
116     }
117 
118     for (OpOperand *operand : producer.getInputOperands()) {
119       AffineMap newIndexingMap =
120           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
121               operand, producerResultIndexMap, consumerIndexMap);
122       addToCoveredDims(newIndexingMap);
123     }
124     if (!coveredDims.all())
125       return false;
126   }
127 
128   return true;
129 }
130 
131 /// Generate the region of the fused tensor operation. The region of the fused
132 /// op must be empty.
133 static void
134 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
135                                  AffineMap consumerToProducerLoopsMap,
136                                  OpOperand *consumerOpOperand,
137                                  unsigned nloops) {
138   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
139   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
140   // Build the region of the fused op.
141   Block &producerBlock = producer->getRegion(0).front();
142   Block &consumerBlock = consumer->getRegion(0).front();
143   Block *fusedBlock = new Block();
144   fusedOp.region().push_back(fusedBlock);
145   BlockAndValueMapping mapper;
146   OpBuilder::InsertionGuard guard(rewriter);
147   rewriter.setInsertionPointToStart(fusedBlock);
148 
149   // 2. Add an index operation for every fused loop dimension and use the
150   // `consumerToProducerLoopsMap` to map the producer indices.
151   if (producer.hasIndexSemantics()) {
152     // Add an index operation for every fused loop dimension.
153     unsigned numFusedOpLoops =
154         std::max(producer.getNumLoops(), consumer.getNumLoops());
155     SmallVector<Value> fusedIndices;
156     fusedIndices.reserve(numFusedOpLoops);
157     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
158                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
159                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
160                     });
161     for (IndexOp indexOp :
162          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
163       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
164           producer.getLoc(),
165           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
166       mapper.map(indexOp.getResult(), newIndex);
167     }
168   }
169   // TODO: allow fusing the producer of an output operand.
170   assert(consumer.isInputTensor(consumerOpOperand) &&
171          "expected producer of input operand");
172   // 3. Consumer input operands up to consumerIdx (exclusive).
173   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
174            consumerOpOperand->getOperandNumber())) // input assumption.
175     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
176 
177   // Replacing consumerIdx requires getting the cloned, yielded, value from
178   // the (cloned) producer block. This happens in step 9.
179 
180   // 4. Splice in producer's input operands.
181   for (BlockArgument bbArg :
182        producerBlock.getArguments().take_front(producer.getNumInputs()))
183     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
184 
185   // 4.b. Producer output operand/map that is fused needs to be mapped to the
186   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
187   assert(producer->getNumResults() == 1 && "expected single result producer");
188   if (producer.isInitTensor(producer.getOutputOperand(0))) {
189     BlockArgument bbArg = producerBlock.getArguments()
190                               .drop_front(producer.getNumInputs())
191                               // TODO: bbArg index of
192                               .front();
193     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
194   }
195   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
196   for (BlockArgument bbArg :
197        consumerBlock.getArguments()
198            .take_front(consumer.getNumInputs())
199            .drop_front(consumerOpOperand->getOperandNumber() + 1))
200     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
201   // 6. All of consumer's output operands.
202   for (BlockArgument bbArg :
203        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
204     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
205   // 7. All of producer's output operands except the one fused.
206   // TODO: allow fusion of multi-result producers.
207   assert(producer->getNumResults() == 1 && "expected single result producer");
208 
209   // 8. Clone all producer operations except for the yield and index operations
210   // to the fused operation.
211   for (auto &op : producerBlock.without_terminator()) {
212     if (!isa<IndexOp>(op))
213       rewriter.clone(op, mapper);
214   }
215   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
216   // forward the yield operand.
217   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
218   // TODO: allow fusion of multi-result producers.
219   assert(producer->getNumResults() == 1 && "expected single result producer");
220   unsigned producerResultNumber = 0;
221   Value replacement =
222       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
223   // Sanity checks, if replacement is not already in the mapper then it must be
224   // produced outside.
225   if (replacement == yieldOp.getOperand(producerResultNumber)) {
226     if (auto bb = replacement.dyn_cast<BlockArgument>())
227       assert(bb.getOwner() != &producerBlock &&
228              "yielded block argument must have been mapped");
229     else
230       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
231              "yielded value must have been mapped");
232   }
233   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
234              replacement);
235   // 10. Clone operations from the consumer to the fused op.
236   for (auto &op : consumerBlock.getOperations())
237     rewriter.clone(op, mapper);
238 
239   // Sanity checks.
240   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
241          "Ill-formed GenericOp region");
242 }
243 
244 static Optional<SmallVector<Value>>
245 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
246                        const ControlElementwiseOpsFusionFn &controlFn,
247                        PatternRewriter &rewriter) {
248   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
249   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
250       !controlFn(producer->getResult(0), *consumerOpOperand))
251     return llvm::None;
252 
253   // TODO: allow fusing the producer of an output operand.
254   assert(consumer.isInputTensor(consumerOpOperand) &&
255          "expected producer of input operand");
256 
257   // Compute the fused operands list and indexing maps.
258   SmallVector<Value> fusedOperands;
259   SmallVector<AffineMap> fusedIndexMaps;
260   fusedOperands.reserve(producer->getNumOperands() +
261                         consumer->getNumOperands());
262   fusedIndexMaps.reserve(producer->getNumOperands() +
263                          consumer->getNumOperands());
264   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
265   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
266   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
267   SmallVector<OpOperand *>::iterator it =
268       llvm::find(consumerInputs, consumerOpOperand);
269   assert(it != consumerInputs.end() && "expected to find the consumer operand");
270   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
271     fusedOperands.push_back(opOperand->get());
272     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
273   }
274   // 4. Splice in producer's input operands/maps.
275   assert(producer->getNumResults() == 1 && "expected single result producer");
276   AffineMap producerResultIndexMap =
277       producer.getTiedIndexingMap(producer.getOutputOperand(0));
278   for (OpOperand *opOperand : producer.getInputOperands()) {
279     fusedOperands.push_back(opOperand->get());
280     // Compute indexing maps for the producer args in the fused operation.
281     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
282         opOperand, producerResultIndexMap,
283         consumer.getTiedIndexingMap(consumerOpOperand));
284     fusedIndexMaps.push_back(map);
285   }
286   // 4.b. Producer output operand/map that is fused needs to be passed if it is
287   // an "initTensor" (i.e. its value is actually read).
288   assert(producer->getNumResults() == 1 && "expected single result producer");
289   if (producer.isInitTensor(producer.getOutputOperand(0))) {
290     fusedOperands.push_back(producer.getOutputOperand(0)->get());
291     // Compute indexing maps for the producer args in the fused operation.
292     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
293         producer.getOutputOperand(0), producerResultIndexMap,
294         consumer.getTiedIndexingMap(consumerOpOperand));
295     fusedIndexMaps.push_back(map);
296   }
297   // 5. Remaining consumer's input operands/maps (drop past index
298   // `consumerIdx`).
299   for (OpOperand *opOperand :
300        llvm::make_range(std::next(it), consumerInputs.end())) {
301     fusedOperands.push_back(opOperand->get());
302     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
303   }
304   // 6. All of consumer's output operands (skip operands: added by the builder).
305   for (OpOperand *opOperand : consumer.getOutputOperands())
306     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
307   // 7. All of producer's output operands/maps except the one fused.
308   // TODO: allow fusion of multi-result producers.
309   assert(producer->getNumResults() == 1 && "expected single result producer");
310 
311   // Generate the fused op.
312   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
313   auto fusedOp = rewriter.create<GenericOp>(
314       consumer.getLoc(), consumer->getResultTypes(),
315       /*inputs=*/fusedOperands,
316       // TODO: handle outputs.
317       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
318       consumer.iterator_types(),
319       /*doc=*/nullptr,
320       /*library_call=*/nullptr);
321   if (!fusedOp.getShapesToLoopsMap()) {
322     // Fused op has invalid indexing maps. Typically this means something is off
323     // in the input, but going ahead here would result in verification errors.
324     // So cleanup and abort.
325     rewriter.eraseOp(fusedOp);
326     return llvm::None;
327   }
328 
329   // Construct an AffineMap from consumer loops to producer loops.
330   // consumer loop -> tensor index
331   AffineMap consumerResultIndexMap =
332       consumer.getTiedIndexingMap(consumerOpOperand);
333   // tensor index -> producer loop
334   AffineMap invProducerResultIndexMap =
335       inversePermutation(producerResultIndexMap);
336   assert(invProducerResultIndexMap &&
337          "expected producer result indexig map to be invertible");
338   // consumer loop -> producer loop
339   AffineMap consumerToProducerLoopsMap =
340       invProducerResultIndexMap.compose(consumerResultIndexMap);
341 
342   generateFusedElementwiseOpRegion(rewriter, fusedOp,
343                                    consumerToProducerLoopsMap,
344                                    consumerOpOperand, consumer.getNumLoops());
345   return SmallVector<Value>(fusedOp->getResults());
346 }
347 
348 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
349 /// provided, given the shape of the source tensor that corresponds to the
350 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
351 /// are "row-major" ordered logically.
352 ///
353 /// For example:
354 ///
355 /// %0 = op ... : tensor<?x?x4x5xf32>
356 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
357 ///
358 /// and reshape:
359 /// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] :
360 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
361 ///
362 /// would be rewritten into:
363 /// %0 = op ... : tensor<?x?x4x5xf32>
364 /// with output index_map
365 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
366 template <typename TensorReshapeOp>
367 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
368                                         TensorReshapeOp reshapeOp) {
369   constexpr bool isExpanding =
370       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
371   ArrayRef<int64_t> sourceShape =
372       (isExpanding ? reshapeOp.getResultType().getShape()
373                    : reshapeOp.getSrcType().getShape());
374   SmallVector<AffineExpr> resultExprs;
375   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
376   MLIRContext *context = sourceMap.getContext();
377 
378   // Compute the result exprs based on the reassociation maps.
379   for (auto &indices : reshapeOp.getReassociationIndices()) {
380     // Assume that they are in-order and contiguous (already checked in
381     // verifier).
382     assert(!indices.empty());
383     SmallVector<int64_t> sizes;
384     SmallVector<AffineExpr> dimExprs;
385     for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
386                              sourceExprs.slice(indices[0], indices.size()))) {
387       if (std::get<0>(en) == 1)
388         continue;
389       sizes.push_back(std::get<0>(en));
390       dimExprs.push_back(std::get<1>(en));
391     }
392     AffineExpr linearizedExpr =
393         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
394     resultExprs.push_back(linearizedExpr);
395   }
396   // The new affine map cannot drop unused dimension but some new symbols may
397   // have been added. Create a map with at least as many dimensions/symbols as
398   // the original affine map.
399   int64_t maxDim = -1;
400   int64_t maxSym = -1;
401   getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
402   unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
403   unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
404   return AffineMap::get(numDims, numSyms, resultExprs, context);
405 }
406 
407 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
408 // producer). Fusing when operand has higher rank will require use of mods and
409 // divs in the indexing maps of the fused op which would make it non-invertible.
410 static bool isTensorReshapeOpFoldableByLinearization(
411     tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
412   if (!asProducer)
413     return false;
414   return useIndexMap.isPermutation();
415 }
416 
417 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
418 // consumer).
419 static bool
420 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
421                                          AffineMap useIndexMap,
422                                          bool asProducer) {
423   if (asProducer)
424     return false;
425   return useIndexMap.isPermutation();
426 }
427 
428 /// Check if the reshape operation is only expansion into/collapsing of
429 /// unit-dimension.
430 template <typename TensorReshapeOp>
431 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
432   constexpr bool isExpanding =
433       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
434   ArrayRef<int64_t> expandedShape =
435       (isExpanding ? reshapeOp.getResultType().getShape()
436                    : reshapeOp.getSrcType().getShape());
437   for (auto &indices : reshapeOp.getReassociationIndices()) {
438     unsigned numUnitDims = 0;
439     for (int64_t position : indices)
440       if (expandedShape[position] == 1)
441         numUnitDims++;
442     if (numUnitDims != indices.size() - 1)
443       return false;
444   }
445   return true;
446 }
447 
448 /// Conditions for folding a generic operation with a reshape op by expanding
449 /// the iteration space dimensionality for tensor operations. These are
450 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
451 /// following fusion pattern.
452 ///
453 ///  Consider
454 ///
455 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
456 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
457 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
458 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
459 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
460 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
461 ///
462 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
463 ///  is increased to match the result (operand) of the tensor_expand_shape.
464 ///  The indexing_map of the fused tensor in the `genericOp` and the
465 ///  reassociation map helps compute the indexing maps of the modified op.
466 ///  For the above example, based on the reassociation map it
467 ///  can be concluded that
468 ///
469 ///  - The loop used to access the first dimension of the fused tensor is split
470 ///    into two.
471 ///  - The loop used to access the second dimension of the fused tensor is kept
472 ///    as is.
473 ///  - The loop used to access the third dimension of the fused tensor is split
474 ///    into three.
475 ///
476 ///  i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
477 ///  op, then
478 ///
479 ///   d0 -> e0, e1
480 ///   d1 -> e2, e3, e4
481 ///   d2 -> e5
482 ///
483 ///  substituting this, the generic op can be rewritten as
484 ///
485 ///  %d = linalg.generic ins(%0, %1 : )
486 ///        indexing_maps =
487 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
488 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
489 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
490 ///
491 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
492 ///  to make it consistent
493 ///
494 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
495 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
496 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
497 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
498 ///
499 ///  The added reshapes are again expanding patterns, so they will get fused
500 ///  with its producers if possible.
501 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
502                                                OpOperand *fusableOpOperand) {
503   // Is fusable only if:
504   // - All the indexing maps for operands and results are projected
505   //   permutations.
506   // - The fused tensor is not a scalar.
507   // - All the loops are parallel loops.
508   return genericOp.hasTensorSemantics() &&
509          llvm::all_of(genericOp.indexing_maps().getValue(),
510                       [](Attribute attr) {
511                         return attr.cast<AffineMapAttr>()
512                             .getValue()
513                             .isProjectedPermutation();
514                       }) &&
515          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
516          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
517            return attr.cast<StringAttr>().getValue() ==
518                   getParallelIteratorTypeName();
519          });
520 }
521 
522 namespace {
523 /// Information needed to expand a generic operation to fold the reshape with
524 /// it.
525 class ExpansionInfo {
526 public:
527   // Computes the mapping from original dimensions of the op to the dimensions
528   // of the expanded op given the `indexingMap` of the fused operand/result of
529   // the generic op, the `reassocationMaps` of the reshape op and the shape of
530   // the expanded op.
531   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
532                         ArrayRef<AffineMap> reassociationMaps,
533                         ArrayRef<int64_t> expandedShape,
534                         ArrayRef<int64_t> collapsedShape,
535                         PatternRewriter &rewriter);
536   unsigned getOrigOpNumDims() const { return reassociation.size(); }
537   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
538   ReassociationIndicesRef getExpandedDims(unsigned i) const {
539     return reassociation[i];
540   }
541   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
542     return expandedShapeMap[i];
543   }
544   ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
545 
546 private:
547   /// Reassociation from the dimensions in the original operation to the
548   /// dimension of the expanded operation.
549   SmallVector<ReassociationIndices> reassociation;
550   /// Mapping from extent of loops in the original operation, to the extent of
551   /// loops in the expanded operation.
552   SmallVector<SmallVector<int64_t>> expandedShapeMap;
553   /// Extent of the loop in the original operation.
554   SmallVector<int64_t> originalLoopExtent;
555   unsigned expandedOpNumDims;
556 };
557 } // namespace
558 
559 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
560                                      OpOperand *fusableOpOperand,
561                                      ArrayRef<AffineMap> reassociationMaps,
562                                      ArrayRef<int64_t> expandedShape,
563                                      ArrayRef<int64_t> collapsedShape,
564                                      PatternRewriter &rewriter) {
565   if (reassociationMaps.empty())
566     return failure();
567   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
568 
569   Optional<SmallVector<int64_t, 4>> originalLoopRange =
570       linalgOp.getStaticLoopRanges();
571   if (!originalLoopRange)
572     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
573   originalLoopExtent.assign(originalLoopRange->begin(),
574                             originalLoopRange->end());
575 
576   reassociation.clear();
577   expandedShapeMap.clear();
578   // Compute the number of dimension in the expanded op that correspond to each
579   // dimension of the original op.
580   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
581   expandedShapeMap.resize(fusedIndexMap.getNumDims());
582   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
583     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
584     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
585     numExpandedDims[pos] = foldedDims.getNumResults();
586     ArrayRef<int64_t> shape =
587         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
588     expandedShapeMap[pos].assign(shape.begin(), shape.end());
589   }
590   // The remaining dimensions remain the same.
591   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
592     if (expandedShapeMap[i].empty())
593       expandedShapeMap[i] = {originalLoopExtent[i]};
594 
595   // Compute reassociation map from the original op to the expanded op.
596   unsigned sum = 0;
597   reassociation.reserve(fusedIndexMap.getNumDims());
598   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
599     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
600     reassociation.emplace_back(seq.begin(), seq.end());
601     sum += numFoldedDim.value();
602   }
603   expandedOpNumDims = sum;
604   return success();
605 }
606 
607 /// Epanding the body of a linalg operation requires adaptations of the accessed
608 /// loop indices. Specifically, access of indices in the original operation need
609 /// to be replaced with linearizations of indices in the expanded op. That
610 /// requires the shape of the expanded dimensions to be static (at least all but
611 /// the most significant). For now check that these are all statically sized.
612 /// Note that this could be extended to handle dynamic case, but the
613 /// implementation below uses `affine.apply` which seems to have issues when the
614 /// shapes are not static.
615 LogicalResult isGenericOpExpandable(GenericOp genericOp,
616                                     const ExpansionInfo &expansionInfo,
617                                     PatternRewriter &rewriter) {
618   if (!genericOp.hasIndexSemantics())
619     return success();
620   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
621     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
622     if (expandedShape.size() == 1)
623       continue;
624     for (int64_t shape : expandedShape.drop_front()) {
625       if (ShapedType::isDynamic(shape)) {
626         return rewriter.notifyMatchFailure(
627             genericOp, "cannot expand due to index semantics and dynamic dims");
628       }
629     }
630   }
631   return success();
632 }
633 
634 /// Return the indexing map to use in the expanded op for a given the
635 /// `indexingMap` of the original operation.
636 static AffineMap
637 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
638                            const ExpansionInfo &expansionInfo) {
639   SmallVector<AffineExpr> newExprs;
640   for (AffineExpr expr : indexingMap.getResults()) {
641     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
642     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
643         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
644           return builder.getAffineDimExpr(static_cast<unsigned>(v));
645         }));
646     newExprs.append(expandedExprs.begin(), expandedExprs.end());
647   }
648   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
649                         indexingMap.getNumSymbols(), newExprs,
650                         builder.getContext());
651 }
652 
653 /// Return the type of the operand/result to use in the expanded op given the
654 /// type in the original op.
655 static RankedTensorType getExpandedType(RankedTensorType originalType,
656                                         AffineMap indexingMap,
657                                         const ExpansionInfo &expansionInfo) {
658   SmallVector<int64_t> expandedShape;
659   for (AffineExpr expr : indexingMap.getResults()) {
660     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
661     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
662     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
663   }
664   return RankedTensorType::get(expandedShape, originalType.getElementType());
665 }
666 
667 /// Returns the reassociation maps to use in the `tensor.expand_shape`
668 /// operation to convert the operands of the original operation to operands of
669 /// the expanded operation. The same method is used to compute the
670 /// `tensor.collapse_shape` used to collapse the result of the expanded
671 /// op to get the value that can replace all uses of the results of the original
672 /// op.
673 static SmallVector<ReassociationIndices>
674 getReassociationForExpansion(AffineMap indexingMap,
675                              const ExpansionInfo &expansionInfo) {
676   SmallVector<ReassociationIndices> reassociation;
677   unsigned numReshapeDims = 0;
678   for (AffineExpr expr : indexingMap.getResults()) {
679     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
680     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
681     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
682         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
683     reassociation.emplace_back(std::move(indices));
684     numReshapeDims += numExpandedDims;
685   }
686   return reassociation;
687 }
688 
689 /// Update the body of an expanded linalg operation having index semantics. The
690 /// indices of the original operation need to be recovered by linearizing the
691 /// indices of the correspoding dimensions of the expanded operation. For now it
692 /// is assumed that the shapes of the expanded operation needed for
693 /// linearization are static.
694 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
695                                           Location loc, Region &fusedRegion,
696                                           const ExpansionInfo &expansionInfo) {
697   // Replace the original indices by the linearization of the expanded indices.
698   for (IndexOp indexOp :
699        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
700     ArrayRef<int64_t> expandedDims =
701         expansionInfo.getExpandedDims(indexOp.dim());
702     assert(!expandedDims.empty() && "expected valid expansion info");
703 
704     // Skip index operations that are not affected by the expansion.
705     if (expandedDims.size() == 1 &&
706         expandedDims.front() == (int64_t)indexOp.dim())
707       continue;
708 
709     // Linearize the expanded indices of the original index dimension.
710     OpBuilder::InsertionGuard guard(rewriter);
711     rewriter.setInsertionPointAfter(indexOp);
712     ArrayRef<int64_t> expandedDimsShape =
713         expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
714     SmallVector<Value> expandedIndices;
715     expandedIndices.reserve(expandedDims.size() - 1);
716     llvm::transform(
717         expandedDims.drop_front(), std::back_inserter(expandedIndices),
718         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
719     Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
720     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
721       assert(!ShapedType::isDynamic(std::get<0>(it)));
722       AffineExpr idx, acc;
723       bindDims(rewriter.getContext(), idx, acc);
724       newIndex = rewriter.create<AffineApplyOp>(
725           indexOp.getLoc(), idx + acc * std::get<0>(it),
726           ValueRange{std::get<1>(it), newIndex});
727     }
728     rewriter.replaceOp(indexOp, newIndex);
729   }
730 }
731 
732 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
733 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
734 /// that those conditions have been satisfied.
735 static Optional<SmallVector<Value>>
736 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
737                            OpOperand *fusableOpOperand,
738                            PatternRewriter &rewriter) {
739   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
740          "preconditions for fuse operation failed");
741   // Check if reshape is expanding or collapsing.
742   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
743   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
744   bool isExpanding = (expandingReshapeOp != nullptr);
745   RankedTensorType expandedType = isExpanding
746                                       ? expandingReshapeOp.getResultType()
747                                       : collapsingReshapeOp.getSrcType();
748   RankedTensorType collapsedType = isExpanding
749                                        ? expandingReshapeOp.getSrcType()
750                                        : collapsingReshapeOp.getResultType();
751 
752   ExpansionInfo expansionInfo;
753   if (failed(expansionInfo.compute(
754           genericOp, fusableOpOperand,
755           isExpanding ? expandingReshapeOp.getReassociationMaps()
756                       : collapsingReshapeOp.getReassociationMaps(),
757           expandedType.getShape(), collapsedType.getShape(), rewriter)))
758     return llvm::None;
759 
760   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
761     return llvm::None;
762 
763   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
764       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
765         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
766       }));
767 
768   SmallVector<Value> expandedOpOperands;
769   expandedOpOperands.reserve(genericOp.getNumInputs());
770   for (OpOperand *opOperand : genericOp.getInputOperands()) {
771     if (opOperand == fusableOpOperand) {
772       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
773                                                : collapsingReshapeOp.src());
774       continue;
775     }
776     if (genericOp.isInputTensor(opOperand)) {
777       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
778       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
779       RankedTensorType expandedOperandType =
780           getExpandedType(opOperandType, indexingMap, expansionInfo);
781       if (expandedOperandType != opOperand->get().getType()) {
782         // Reshape the operand to get the right type.
783         SmallVector<ReassociationIndices> reassociation =
784             getReassociationForExpansion(indexingMap, expansionInfo);
785         if (failed(reshapeLikeShapesAreCompatible(
786                 [&](const Twine &msg) {
787                   return rewriter.notifyMatchFailure(genericOp, msg);
788                 },
789                 opOperandType.getShape(), expandedOperandType.getShape(),
790                 reassociation,
791                 /*isExpandingReshape=*/true)))
792           return llvm::None;
793         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
794             genericOp.getLoc(), expandedOperandType, opOperand->get(),
795             reassociation));
796         continue;
797       }
798     }
799     expandedOpOperands.push_back(opOperand->get());
800   }
801 
802   Location loc = genericOp.getLoc();
803   SmallVector<Value> outputs;
804   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
805     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
806     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
807     RankedTensorType expandedOutputType =
808         getExpandedType(opOperandType, indexingMap, expansionInfo);
809     if (expandedOutputType != opOperand->get().getType()) {
810       SmallVector<ReassociationIndices> reassociation =
811           getReassociationForExpansion(indexingMap, expansionInfo);
812       if (failed(reshapeLikeShapesAreCompatible(
813               [&](const Twine &msg) {
814                 return rewriter.notifyMatchFailure(genericOp, msg);
815               },
816               opOperandType.getShape(), expandedOutputType.getShape(),
817               reassociation,
818               /*isExpandingReshape=*/true)))
819         return llvm::None;
820       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
821           genericOp.getLoc(), expandedOutputType, opOperand->get(),
822           reassociation));
823     }
824   }
825 
826   // The iterator types of the expanded op are all parallel.
827   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
828                                        getParallelIteratorTypeName());
829 
830   TypeRange resultTypes = ValueRange(outputs).getTypes();
831   auto fusedOp =
832       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
833                                  /*inputs=*/expandedOpOperands, outputs,
834                                  expandedOpIndexingMaps, iteratorTypes);
835   Region &fusedRegion = fusedOp->getRegion(0);
836   Region &originalRegion = genericOp->getRegion(0);
837   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
838 
839   // Update the index accesses after the expansion.
840   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
841 
842   // Reshape the result values to their original shape if this is a collapsing
843   // reshape folded into its consumer.
844   SmallVector<Value> resultVals;
845   for (OpResult opResult : genericOp->getOpResults()) {
846     int64_t resultNumber = opResult.getResultNumber();
847     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
848       SmallVector<ReassociationIndices> reassociation =
849           getReassociationForExpansion(
850               genericOp.getTiedIndexingMap(
851                   genericOp.getOutputOperand(resultNumber)),
852               expansionInfo);
853       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
854           genericOp.getLoc(), opResult.getType(),
855           fusedOp->getResult(resultNumber), reassociation));
856     } else {
857       resultVals.push_back(fusedOp->getResult(resultNumber));
858     }
859   }
860   // Assuming a single result.
861   return resultVals;
862 }
863 
864 namespace {
865 
866 /// Pattern to fold tensor_expand_shape op with its consumer by using the source
867 /// of the reshape op as the operand in the consumer (instead of the result of
868 /// the tensor_collapse_shape). The corresponding index map in the consumer
869 /// needs to be modified to linearize the folded dimension.
870 ///
871 /// For example,
872 ///
873 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
874 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
875 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
876 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
877 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
878 ///        -> tensor<?x?x4x?xf32>
879 ///
880 /// can be folded into
881 ///
882 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
883 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
884 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
885 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
886 ///        -> tensor<?x?x4x?xf32>
887 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
888 struct FoldProducerReshapeOpByLinearization
889     : public OpRewritePattern<GenericOp> {
890   using OpRewritePattern<GenericOp>::OpRewritePattern;
891 
892   LogicalResult matchAndRewrite(GenericOp genericOp,
893                                 PatternRewriter &rewriter) const override {
894     if (!genericOp.hasTensorSemantics())
895       return failure();
896     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
897     for (const auto &en : llvm::enumerate(inputOperands)) {
898       auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
899       if (!reshapeOp)
900         continue;
901 
902       if (!isTensorReshapeOpFoldableByLinearization(
903               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
904               /*asProducer =*/true) ||
905           (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
906         continue;
907 
908       // Compute the fused operands list,
909       SmallVector<Value> fusedOperands = genericOp.getInputOperands();
910       fusedOperands[en.index()] = reshapeOp.src();
911       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
912       llvm::append_range(fusedOperands, outputOperands);
913 
914       // Compute indexing_maps for the fused operation. The indexing_maps for
915       // the operands of the consumers that arent fused are the same.
916       SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
917 
918       // Accepted consumer maps are either identity or permutation.
919       auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
920 
921       // Compute the indexing map to use for the result of the producer.
922       AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
923       // The modified map cannot have symbols.
924       if (modifiedMap.getNumSymbols())
925         return failure();
926       for (AffineExpr expr : modifiedMap.getResults()) {
927         if (!expr.isPureAffine())
928           return failure();
929       }
930       fusedIndexMaps[en.index()] = modifiedMap;
931 
932       // Further check that the resulting index maps can be fused and
933       // inverted. Without this the resultant op is not legal.
934       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
935         return rewriter.notifyMatchFailure(
936             genericOp, "fused op loop bound computation failed");
937       }
938 
939       rewriter.startRootUpdate(genericOp);
940       genericOp->setOperands(fusedOperands);
941       genericOp.indexing_mapsAttr(
942           rewriter.getAffineMapArrayAttr(fusedIndexMaps));
943       rewriter.finalizeRootUpdate(genericOp);
944       return success();
945     }
946     return failure();
947   }
948 };
949 
950 static SmallVector<ReassociationIndices>
951 getReassociationIndices(ArrayRef<AffineMap> maps) {
952   SmallVector<ReassociationIndices> reassociation;
953   for (AffineMap map : maps) {
954     ReassociationIndices indices;
955     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
956       unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
957       indices.push_back(pos);
958     }
959     reassociation.push_back(indices);
960   }
961   return reassociation;
962 }
963 
964 /// Pattern to move rank reducing reshape after an elementwise linalg generic
965 /// op. This is useful to expose more fusion opportunities between named ops and
966 /// generic ops. This can only be done if there is no broadcast or permuation
967 /// within the dimensions we need to merge.
968 ///
969 /// For example,
970 ///
971 ///  %0 = tensor.expand_shape %A [[0, 1], [2]]
972 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
973 ///  %2 = linalg.generic {indexing_maps = [
974 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
975 ///    affine_map<(d0, d1, d2) -> (d2)>,
976 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
977 ///    ["parallel", "parallel", "parallel"]} {
978 ///  } -> tensor<112x112x16xf32>
979 ///
980 ///  into
981 ///
982 ///  %2 = linalg.generic {indexing_maps = [
983 ///    affine_map<(d0, d1) -> (d0, d1)>,
984 ///    affine_map<(d0, d1) -> (d1)>,
985 ///    affine_map<(d0, d1) -> (d0, d1)>],
986 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
987 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
988 ///  } -> tensor<12544x16xf32>
989 ///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
990 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
991 struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
992   using OpRewritePattern<GenericOp>::OpRewritePattern;
993 
994   LogicalResult matchAndRewrite(GenericOp genericOp,
995                                 PatternRewriter &rewriter) const override {
996     // Only apply to elementwise linalg on tensor.
997     if (!genericOp.hasTensorSemantics() ||
998         genericOp.getNumParallelLoops() != genericOp.getNumLoops())
999       return failure();
1000     // Only support identity output maps. It could be extended to permuations if
1001     // needed.
1002     if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
1003           return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
1004         }))
1005       return failure();
1006     int64_t destRank = genericOp.getNumParallelLoops();
1007     SmallVector<Value> newOperands = genericOp.getInputOperands();
1008     tensor::ExpandShapeOp reshapeFound;
1009     // 1. Look for tensor_expand_shape operands and figure out save the
1010     // dimensions merged.
1011     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
1012     for (const auto &en : llvm::enumerate(inputOperands)) {
1013       auto reshapeOp =
1014           en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
1015       if (!reshapeOp)
1016         continue;
1017       // TODO: We could support non-identity map as long as the merged
1018       // dimensions are still contiguous.
1019       if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
1020         continue;
1021       if (reshapeFound) {
1022         // Only support a second reshape op if it has the same reassociate maps.
1023         if (reshapeFound.getReassociationMaps() ==
1024             reshapeOp.getReassociationMaps())
1025           newOperands[en.index()] = reshapeOp.src();
1026         continue;
1027       }
1028       reshapeFound = reshapeOp;
1029       newOperands[en.index()] = reshapeOp.src();
1030     }
1031     if (!reshapeFound)
1032       return failure();
1033 
1034     // Calculate the reassociation indices and rassociated reverse map.
1035     SmallVector<ReassociationIndices> reassociation =
1036         getReassociationIndices(reshapeFound.getReassociationMaps());
1037     SmallVector<unsigned> remap(destRank);
1038     for (auto &indices : llvm::enumerate(reassociation)) {
1039       for (int64_t index : indices.value()) {
1040         remap[index] = indices.index();
1041       }
1042     }
1043     // 2. Verify that we can merge the dimensions in the linalg and that we
1044     // don't need to create new reshapes operands. Inserting new reshape
1045     // operands would defeat the purpose of the transformation.
1046     for (const auto &en : llvm::enumerate(inputOperands)) {
1047       if (en.value()->get() == newOperands[en.index()]) {
1048         AffineMap map = genericOp.getTiedIndexingMap(en.value());
1049         for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
1050           if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
1051             return failure();
1052         }
1053       }
1054     }
1055 
1056     // 3. Calculate the affine map remapping and the reassociation to apply to
1057     // output tensors.
1058     SmallVector<AffineMap> newMaps;
1059     unsigned newRank = reassociation.size();
1060     for (auto map : genericOp.getIndexingMaps()) {
1061       SmallVector<AffineExpr> newExprs;
1062       for (auto expr : map.getResults()) {
1063         unsigned position = expr.template cast<AffineDimExpr>().getPosition();
1064         // Skip dimension merged except for the last of the group.
1065         if (reassociation[remap[position]].back() == position) {
1066           newExprs.push_back(
1067               getAffineDimExpr(remap[position], genericOp.getContext()));
1068         }
1069       }
1070       newMaps.push_back(
1071           AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
1072     }
1073 
1074     // 4. Reshape the output tensors.
1075     SmallVector<Value> newOutputs;
1076     SmallVector<Type> newOutputTypes;
1077     for (auto output : genericOp.outputs()) {
1078       auto newOutputType = RankedTensorType::get(
1079           reshapeFound.getSrcType().getShape(),
1080           output.getType().template cast<RankedTensorType>().getElementType());
1081       Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
1082           genericOp->getLoc(), newOutputType, output, reassociation);
1083       newOutputTypes.push_back(newOutputType);
1084       newOutputs.push_back(newOutput);
1085     }
1086     // 5. Create a new generic op with lowerer rank.
1087     SmallVector<StringRef> iteratorTypes(newRank,
1088                                          getParallelIteratorTypeName());
1089     auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
1090                                             newOperands, newOutputs, newMaps,
1091                                             iteratorTypes);
1092     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
1093                                 newOp.region().begin());
1094     // 6. Reshape the so that the type matches the uses.
1095     SmallVector<Value> newResults;
1096     for (const auto &result : llvm::enumerate(newOp->getResults())) {
1097       newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
1098           genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
1099           result.value(), reassociation));
1100     }
1101     rewriter.replaceOp(genericOp, newResults);
1102     return success();
1103   }
1104 };
1105 
1106 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
1107 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
1108 /// in the consumer is expanded.
1109 class FoldWithProducerReshapeOpByExpansion
1110     : public OpRewritePattern<GenericOp> {
1111 public:
1112   FoldWithProducerReshapeOpByExpansion(
1113       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1114       PatternBenefit benefit = 1)
1115       : OpRewritePattern<GenericOp>(context, benefit),
1116         controlFoldingReshapes(std::move(foldReshapes)) {}
1117 
1118   LogicalResult matchAndRewrite(GenericOp genericOp,
1119                                 PatternRewriter &rewriter) const override {
1120     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1121       tensor::CollapseShapeOp reshapeOp =
1122           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1123       if (!reshapeOp)
1124         continue;
1125       // Fold only if
1126       // - The tensor reshape op is folding.
1127       // - All constraints of fusing with reshape by expansion are met.
1128       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1129           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1130         continue;
1131 
1132       Optional<SmallVector<Value>> replacementValues =
1133           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1134       if (!replacementValues)
1135         return failure();
1136       rewriter.replaceOp(genericOp, replacementValues.getValue());
1137       return success();
1138     }
1139     return failure();
1140   }
1141 
1142 private:
1143   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1144 };
1145 
1146 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
1147 /// producer. The corresponding index map in the consumer needs to be modified
1148 /// to linearize the folded dimension.
1149 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
1150 struct FoldConsumerReshapeOpByLinearization
1151     : public OpRewritePattern<TensorReshapeOp> {
1152   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1153 
1154   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1155                                 PatternRewriter &rewriter) const override {
1156     GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
1157     if (!producer || !producer.hasTensorSemantics() ||
1158         producer.getNumOutputs() != 1 ||
1159         !isTensorReshapeOpFoldableByLinearization(
1160             reshapeOp,
1161             producer.getTiedIndexingMap(producer.getOutputOperand(0)),
1162             /*asProducer =*/false) ||
1163         (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
1164       return failure();
1165     // The indexing_maps for the operands of the fused operation are same as
1166     // those for the operands of the producer.
1167     SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
1168 
1169     auto invMap = inversePermutation(
1170         producer.getTiedIndexingMap(producer.getOutputOperand(0)));
1171 
1172     // Compute the indexing map to use for the operand of the producer.
1173     AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
1174     for (AffineExpr expr : modifiedMap.getResults()) {
1175       if (!expr.isPureAffine()) {
1176         return rewriter.notifyMatchFailure(
1177             producer, "fused op indexing map is not affine");
1178       }
1179     }
1180     fusedIndexMaps.back() = modifiedMap;
1181 
1182     // Further check that the resulting index maps can be fused and
1183     // inverted. Without this the resultant op is not legal.
1184     if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1185       return rewriter.notifyMatchFailure(
1186           producer, "fused op loop bound computation failed");
1187     }
1188 
1189     Location loc = producer.getLoc();
1190     SmallVector<Value> inputOperands = producer.getInputOperands();
1191     Value output = rewriter.create<TensorReshapeOp>(
1192         loc, producer.getOutputOperand(0)->get(),
1193         reshapeOp.getReassociationExprs());
1194     auto fusedOp = rewriter.create<GenericOp>(
1195         loc, reshapeOp.getResultType(),
1196         /*inputs=*/inputOperands,
1197         // TODO: handle outputs.
1198         /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1199         producer.iterator_types(),
1200         /*doc=*/nullptr,
1201         /*library_call=*/nullptr);
1202     auto &fusedRegion = fusedOp->getRegion(0);
1203     rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
1204                                fusedRegion.begin());
1205     rewriter.replaceOp(reshapeOp, fusedOp->getResults());
1206     return success();
1207   }
1208 };
1209 
1210 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1211 /// by expanding the dimensionality of the loop in the producer op.
1212 struct FoldReshapeWithGenericOpByExpansion
1213     : public OpRewritePattern<tensor::ExpandShapeOp> {
1214 
1215   FoldReshapeWithGenericOpByExpansion(
1216       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1217       PatternBenefit benefit = 1)
1218       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1219         controlFoldingReshapes(std::move(foldReshapes)) {}
1220 
1221   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1222                                 PatternRewriter &rewriter) const override {
1223     // Fold only if all constraints of fusing with reshape by expansion are met.
1224     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1225     if (!producer || producer.getNumOutputs() != 1 ||
1226         !isFusableWithReshapeByDimExpansion(producer,
1227                                             producer.getOutputOperand(0)) ||
1228         !controlFoldingReshapes(producer->getResult(0),
1229                                 reshapeOp->getOpOperand(0)))
1230       return failure();
1231     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1232         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1233     if (!replacementValues)
1234       return failure();
1235     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1236     return success();
1237   }
1238 
1239 private:
1240   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1241 };
1242 
1243 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1244 /// handle cases where the constant is not single-valued.
1245 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1246 public:
1247   FoldScalarOrSplatConstant(MLIRContext *context,
1248                             ControlElementwiseOpsFusionFn &fun,
1249                             PatternBenefit benefit = 1)
1250       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1251 
1252   LogicalResult matchAndRewrite(GenericOp genericOp,
1253                                 PatternRewriter &rewriter) const override {
1254     if (!genericOp.hasTensorSemantics())
1255       return failure();
1256     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1257       Operation *def = opOperand->get().getDefiningOp();
1258       Attribute constantAttr;
1259       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1260         {
1261           DenseElementsAttr splatAttr;
1262           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1263               splatAttr.isSplat() &&
1264               splatAttr.getType().getElementType().isIntOrFloat()) {
1265             constantAttr = splatAttr.getSplatValue<Attribute>();
1266             return true;
1267           }
1268         }
1269         {
1270           IntegerAttr intAttr;
1271           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1272             constantAttr = intAttr;
1273             return true;
1274           }
1275         }
1276         {
1277           FloatAttr floatAttr;
1278           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1279             constantAttr = floatAttr;
1280             return true;
1281           }
1282         }
1283         return false;
1284       };
1285 
1286       auto resultValue = opOperand->get().dyn_cast<OpResult>();
1287       if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
1288           !controlFn(resultValue, *opOperand))
1289         continue;
1290 
1291       // The operands and the indexing_maps of the fused operation the same as
1292       // the operands and indexing_maps of the generic operations with the
1293       // values at the constant index dropped.
1294       SmallVector<AffineMap> fusedIndexMaps;
1295       SmallVector<Value> fusedOperands;
1296       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1297       fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1298       fusedOperands.reserve(genericOp.getNumInputs());
1299       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1300       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1301         if (inputOperand == opOperand)
1302           continue;
1303         Value inputValue = inputOperand->get();
1304         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1305         fusedOperands.push_back(inputValue);
1306         fusedLocs.push_back(inputValue.getLoc());
1307       }
1308       for (OpOperand *outputOperand : genericOp.getOutputOperands())
1309         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1310 
1311       // Check if the operation shapes to loops map is computable.
1312       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1313         return rewriter.notifyMatchFailure(
1314             genericOp, "fused op loop bound computation failed");
1315       }
1316 
1317       // Create a constant scalar value from the splat constant.
1318       Value scalarConstant = rewriter.create<arith::ConstantOp>(
1319           def->getLoc(), constantAttr, constantAttr.getType());
1320 
1321       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1322       auto fusedOp = rewriter.create<GenericOp>(
1323           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1324           /*inputs=*/fusedOperands,
1325           /*outputs=*/outputOperands,
1326           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1327           genericOp.iterator_types(),
1328           /*doc=*/nullptr,
1329           /*library_call=*/nullptr);
1330 
1331       // Map the block argument corresponding to the replaced argument with the
1332       // scalar constant.
1333       Region &region = genericOp->getRegion(0);
1334       Block &entryBlock = *region.begin();
1335       BlockAndValueMapping mapping;
1336       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1337                   scalarConstant);
1338       Region &fusedRegion = fusedOp->getRegion(0);
1339       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1340                                  mapping);
1341       rewriter.replaceOp(genericOp, fusedOp->getResults());
1342       return success();
1343     }
1344     return failure();
1345   }
1346 
1347 private:
1348   ControlElementwiseOpsFusionFn controlFn;
1349 };
1350 
1351 /// Base class for constant folding linalg.generic ops with N inputs, 1 output,
1352 /// and permutation indexing maps.
1353 ///
1354 /// `ConcreteType` should provide methods with signatures
1355 ///
1356 /// ```c++
1357 ///   bool matchIndexingMaps(GenericOp genericOp) const;
1358 ///   RegionComputationFn getRegionComputeFn(GenericOp) const;
1359 /// ```
1360 ///
1361 /// The latter inspects the region and returns the computation inside as a
1362 /// functor. The functor will be invoked with constant elements for all inputs
1363 /// and should return the corresponding computea constant element for output.
1364 template <typename ConcreteType>
1365 class FoldConstantBase : public OpRewritePattern<GenericOp> {
1366 public:
1367   struct APIntOrFloat {
1368     Optional<APInt> apInt;
1369     Optional<APFloat> apFloat;
1370   };
1371   struct APIntOrFloatArray {
1372     SmallVector<APInt> apInts;
1373     SmallVector<APFloat> apFloats;
1374   };
1375   using RegionComputationFn =
1376       std::function<APIntOrFloat(const APIntOrFloatArray &)>;
1377 
1378   FoldConstantBase(MLIRContext *context,
1379                    const ControlElementwiseOpsFusionFn &controlFn,
1380                    PatternBenefit benefit = 1)
1381       : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
1382 
1383   LogicalResult matchAndRewrite(GenericOp genericOp,
1384                                 PatternRewriter &rewriter) const override {
1385     if (genericOp.hasBufferSemantics())
1386       return failure();
1387 
1388     // Only support ops generating one output for now.
1389     if (genericOp.getNumOutputs() != 1)
1390       return failure();
1391 
1392     auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
1393     // Require the output types to be static give we are generating constants.
1394     if (!outputType || !outputType.hasStaticShape())
1395       return failure();
1396 
1397     if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
1398           return operand->get().getType().isa<ShapedType>();
1399         }))
1400       return failure();
1401 
1402     // Make sure all element types are the same.
1403     auto getOperandElementType = [](OpOperand *operand) {
1404       return operand->get().getType().cast<ShapedType>().getElementType();
1405     };
1406     if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
1407                                         getOperandElementType)))
1408       return failure();
1409 
1410     // We can only handle the case where we have int/float elements.
1411     auto elementType = outputType.getElementType();
1412     if (!elementType.isIntOrFloat())
1413       return failure();
1414 
1415     // Require all indexing maps to be permutations for now. This is common and
1416     // it simplifies input/output access greatly: we can do the data shuffling
1417     // entirely in the compiler, without needing to turn all indices into
1418     // Values, and then do affine apply on them, and then match back the
1419     // constant again.
1420     if (!llvm::all_of(genericOp.getIndexingMaps(),
1421                       [](AffineMap map) { return map.isPermutation(); }))
1422       return failure();
1423 
1424     for (OpOperand *operand : genericOp.getOutputOperands()) {
1425       if (genericOp.payloadUsesValueFromOperand(operand))
1426         return failure();
1427     }
1428 
1429     // Further check the indexing maps are okay for the ConcreteType.
1430     if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
1431       return failure();
1432 
1433     // Defer to the concrete type to check the region and discover the
1434     // computation inside.
1435     RegionComputationFn computeFn =
1436         static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
1437     if (!computeFn)
1438       return failure();
1439 
1440     // All inputs should be constants.
1441     int numInputs = genericOp.getNumInputs();
1442     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
1443     for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
1444       if (!matchPattern(operand.value()->get(),
1445                         m_Constant(&inputValues[operand.index()])))
1446         return failure();
1447     }
1448 
1449     // Identified this as a potential candidate for folding. Now check the
1450     // policy to see whether we are allowed to proceed.
1451     for (int i = 0; i < numInputs; ++i) {
1452       OpOperand *consumer = genericOp.getInputOperand(i);
1453       OpResult producer = consumer->get().cast<OpResult>();
1454       if (!controlFn(producer, *consumer))
1455         return failure();
1456     }
1457 
1458     auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
1459     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
1460     int64_t numElements = outputType.getNumElements();
1461 
1462     // Use APInt/APFloat instead of Attribute here for constructing the output.
1463     // This helps to avoid blowing up compiler memory usage: Attributes would
1464     // unify the following cases but they have lifetime as the MLIRContext.
1465     SmallVector<APInt> intOutputValues;
1466     SmallVector<APFloat> fpOutputValues;
1467     if (elementType.template isa<FloatType>())
1468       fpOutputValues.resize(numElements, APFloat(0.f));
1469     else
1470       intOutputValues.resize(numElements);
1471 
1472     // Return the constant dim positions from the given permutation map.
1473     auto getDimPositions = [](AffineMap map) {
1474       SmallVector<unsigned> dims;
1475       dims.reserve(map.getNumResults());
1476       for (AffineExpr result : map.getResults()) {
1477         dims.push_back(result.cast<AffineDimExpr>().getPosition());
1478       }
1479       return dims;
1480     };
1481 
1482     SmallVector<SmallVector<unsigned>> inputDims;
1483     for (int i = 0; i < numInputs; ++i)
1484       inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
1485     auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
1486     auto outputShape = outputType.getShape();
1487 
1488     // Allocate small vectors for index delinearization. Initial values do not
1489     // matter here as they will be overwritten later.
1490     SmallVector<uint64_t> indices(loopBounds.size(), 0);
1491     SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
1492     SmallVector<SmallVector<uint64_t>> srcIndices(
1493         numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
1494     SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
1495     uint64_t dstLinearIndex = 0;
1496 
1497     // Allocate spaces for compute function inputs. Initial values do not matter
1498     // here as they will be overwritten later.
1499     APIntOrFloatArray computeFnInputs;
1500 
1501     auto inputShapes = llvm::to_vector<4>(
1502         llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
1503           return operand->get().getType().cast<ShapedType>().getShape();
1504         }));
1505 
1506     // Given a `linearIndex`, remap it to a linear index to access linalg op
1507     // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
1508     // `srcLinearIndices`, `dstLinearIndex` in place.
1509     auto computeRemappedLinearIndex = [&](int linearIndex) {
1510       int totalCount = linearIndex;
1511       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1512         indices[dim] = totalCount % loopBounds[dim];
1513         totalCount /= loopBounds[dim];
1514       }
1515 
1516       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1517         for (int i = 0; i < numInputs; ++i)
1518           srcIndices[i][dim] = indices[inputDims[i][dim]];
1519         dstIndices[dim] = indices[outputDims[dim]];
1520       }
1521 
1522       dstLinearIndex = dstIndices.front();
1523       for (int i = 0; i < numInputs; ++i)
1524         srcLinearIndices[i] = srcIndices[i].front();
1525 
1526       for (int dim = 1; dim < outputType.getRank(); ++dim) {
1527         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
1528         for (int i = 0; i < numInputs; ++i)
1529           srcLinearIndices[i] =
1530               srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
1531       }
1532     };
1533 
1534     bool isFloat = elementType.isa<FloatType>();
1535     if (isFloat) {
1536       SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
1537       for (int i = 0; i < numInputs; ++i)
1538         inFpRanges.push_back(inputValues[i].getValues<APFloat>());
1539 
1540       computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
1541 
1542       // Transpose the input constant. Because we don't know its rank in
1543       // advance, we need to loop over the range [0, element count) and
1544       // delinearize the index.
1545       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
1546         computeRemappedLinearIndex(linearIndex);
1547 
1548         // Collect constant elements for all inputs at this loop iteration.
1549         for (int i = 0; i < numInputs; ++i)
1550           computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
1551 
1552         // Invoke the computation to get the corresponding constant output
1553         // element.
1554         fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
1555       }
1556     } else {
1557       SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
1558       for (int i = 0; i < numInputs; ++i)
1559         inIntRanges.push_back(inputValues[i].getValues<APInt>());
1560 
1561       computeFnInputs.apInts.resize(numInputs);
1562 
1563       // Transpose the input constant. Because we don't know its rank in
1564       // advance, we need to loop over the range [0, element count) and
1565       // delinearize the index.
1566       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
1567         computeRemappedLinearIndex(linearIndex);
1568 
1569         // Collect constant elements for all inputs at this loop iteration.
1570         for (int i = 0; i < numInputs; ++i)
1571           computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
1572 
1573         // Invoke the computation to get the corresponding constant output
1574         // element.
1575         intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
1576       }
1577     }
1578 
1579     DenseElementsAttr outputAttr =
1580         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
1581                 : DenseElementsAttr::get(outputType, intOutputValues);
1582 
1583     rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
1584     return success();
1585   }
1586 
1587 private:
1588   ControlElementwiseOpsFusionFn controlFn;
1589 };
1590 
1591 // Folds linalg.generic ops that are actually transposes on constant values.
1592 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
1593   using FoldConstantBase::FoldConstantBase;
1594 
1595   bool matchIndexingMaps(GenericOp genericOp) const {
1596     // We should have one input and one output.
1597     return genericOp.getIndexingMaps().size() == 2;
1598   }
1599 
1600   RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
1601     // Make sure the region only contains a yield op.
1602     Block &body = genericOp.region().front();
1603     if (!llvm::hasSingleElement(body))
1604       return nullptr;
1605     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1606     if (!yieldOp)
1607       return nullptr;
1608 
1609     // The yield op should return the block argument corresponds to the input.
1610     for (Value yieldVal : yieldOp.values()) {
1611       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
1612       if (!yieldArg || yieldArg.getOwner() != &body)
1613         return nullptr;
1614       if (yieldArg.getArgNumber() != 0)
1615         return nullptr;
1616     }
1617 
1618     // No computation; just return the orginal value.
1619     return [](const APIntOrFloatArray &inputs) {
1620       if (inputs.apFloats.empty())
1621         return APIntOrFloat{inputs.apInts.front(), llvm::None};
1622       return APIntOrFloat{llvm::None, inputs.apFloats.front()};
1623     };
1624   }
1625 
1626   ControlElementwiseOpsFusionFn controlFn;
1627 };
1628 
1629 } // namespace
1630 
1631 static Optional<SmallVector<Value>>
1632 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
1633                    GenericOp producer,
1634                    const ControlElementwiseOpsFusionFn &controlFn) {
1635   if (producer->getNumResults() != 1)
1636     return llvm::None;
1637 
1638   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
1639                                 rewriter);
1640 }
1641 
1642 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
1643                                       OpOperand &consumer) {
1644   if (auto producerCollapseOp =
1645           dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
1646     return !isUnitDimExpansionOnly(producerCollapseOp);
1647   }
1648   if (auto consumerExpandOp =
1649           dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
1650     return !isUnitDimExpansionOnly(consumerExpandOp);
1651   }
1652   return true;
1653 }
1654 
1655 namespace {
1656 /// Patterns to fuse a generic op, with the producer of its operands.
1657 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
1658 public:
1659   FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
1660                      PatternBenefit benefit = 1)
1661       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1662 
1663   LogicalResult matchAndRewrite(GenericOp genericOp,
1664                                 PatternRewriter &rewriter) const override {
1665     // Find the first operand that is defined by another generic op on tensors.
1666     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
1667       auto producer =
1668           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
1669       if (!producer || !producer.hasTensorSemantics())
1670         continue;
1671       Optional<SmallVector<Value>> fusedOpResults =
1672           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
1673       if (fusedOpResults) {
1674         rewriter.replaceOp(genericOp, *fusedOpResults);
1675         return success();
1676       }
1677     }
1678     return failure();
1679   }
1680 
1681 private:
1682   ControlElementwiseOpsFusionFn controlFn;
1683 };
1684 
1685 /// Pass that fuses generic ops on tensors. Used only for testing.
1686 struct LinalgElementwiseOpFusionPass
1687     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1688   void runOnOperation() override {
1689     Operation *op = getOperation();
1690     RewritePatternSet patterns(op->getContext());
1691     ControlElementwiseOpsFusionFn allowFoldingFn =
1692         [](const OpResult &producer, const OpOperand &consumer) {
1693           return true;
1694         };
1695     populateElementwiseOpsFusionPatterns(
1696         patterns,
1697         LinalgElementwiseFusionOptions().setControlFoldingReshapes(
1698             allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
1699 
1700     // Use TopDownTraversal for compile time reasons
1701     GreedyRewriteConfig grc;
1702     grc.useTopDownTraversal = true;
1703     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1704                                        grc);
1705   }
1706 };
1707 
1708 /// Pass to test folding of reshape ops with generic ops by linearization.
1709 struct FoldReshapeOpsByLinearizationPass
1710     : public LinalgFoldReshapeOpsByLinearizationBase<
1711           FoldReshapeOpsByLinearizationPass> {
1712   void runOnOperation() override {
1713     Operation *op = getOperation();
1714     RewritePatternSet patterns(op->getContext());
1715     populateFoldReshapeOpsByLinearizationPatterns(patterns);
1716     if (allowFoldingUnitDimReshapes) {
1717       populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
1718     }
1719     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1720   }
1721 };
1722 
1723 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
1724 /// the value of the `outs` operand is not used within the op.  This is only
1725 /// implemented for `linalg.generic` operations for now, but should hold for all
1726 /// linalg structured ops.
1727 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1728   using OpRewritePattern<GenericOp>::OpRewritePattern;
1729 
1730   LogicalResult matchAndRewrite(GenericOp op,
1731                                 PatternRewriter &rewriter) const override {
1732     rewriter.startRootUpdate(op);
1733     bool modifiedOutput = false;
1734     Location loc = op.getLoc();
1735     for (OpOperand *opOperand : op.getOutputOperands()) {
1736       if (!op.payloadUsesValueFromOperand(opOperand)) {
1737         Value operandVal = opOperand->get();
1738         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1739         if (!operandType)
1740           continue;
1741 
1742         // If outs is already an `init_tensor` operation, nothing to do.
1743         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1744         if (definingOp)
1745           continue;
1746         modifiedOutput = true;
1747         SmallVector<Value> dynamicDims;
1748         for (const auto &dim : llvm::enumerate(operandType.getShape())) {
1749           if (dim.value() != ShapedType::kDynamicSize)
1750             continue;
1751           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1752               loc, operandVal, dim.index()));
1753         }
1754         Value initTensor = rewriter.create<InitTensorOp>(
1755             loc, dynamicDims, operandType.getShape(),
1756             operandType.getElementType());
1757         op->setOperand(opOperand->getOperandNumber(), initTensor);
1758       }
1759     }
1760     if (!modifiedOutput) {
1761       rewriter.cancelRootUpdate(op);
1762       return failure();
1763     }
1764     rewriter.finalizeRootUpdate(op);
1765     return success();
1766   }
1767 };
1768 
1769 } // namespace
1770 
1771 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
1772     RewritePatternSet &patterns) {
1773   patterns
1774       .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1775            FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
1776            FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1777            FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
1778           patterns.getContext());
1779 }
1780 
1781 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
1782     RewritePatternSet &patterns) {
1783   patterns
1784       .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1785            FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
1786            FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1787            FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
1788           patterns.getContext());
1789 }
1790 
1791 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1792     RewritePatternSet &patterns,
1793     const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
1794   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1795                                                     controlFoldingReshapes);
1796   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1797                                                      controlFoldingReshapes);
1798 }
1799 
1800 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1801     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
1802   auto *context = patterns.getContext();
1803   patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
1804                FoldConstantTranspose>(context,
1805                                       options.controlElementwiseOpsFusionFn);
1806   patterns.add<RemoveOutsDependency>(context);
1807   populateFoldReshapeOpsByExpansionPatterns(patterns,
1808                                             options.controlFoldingReshapesFn);
1809   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1810   GenericOp::getCanonicalizationPatterns(patterns, context);
1811   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1812   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1813   context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1814       patterns);
1815 }
1816 
1817 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
1818   auto *context = patterns.getContext();
1819   patterns.add<PushExpandingReshape>(context);
1820 }
1821 
1822 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1823   return std::make_unique<LinalgElementwiseOpFusionPass>();
1824 }
1825 
1826 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1827   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1828 }
1829