//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/IR/AffineExpr.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
29 /// the `producer` to use in the fused operation given the indexing map of the
30 /// result of the producer in the consumer.
31 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
32     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
33     AffineMap fusedConsumerArgIndexMap) {
34   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
35   // from consumer loop -> consumer arg tensor index/producer result tensor
36   // index. The fused loop is same as the consumer loop. For each producer arg
37   // the indexing map to be computed is a map from consumer loop -> producer
38   // arg tensor index.
39   // producerResultIndexMap is a map from producer loop -> tensor index.
40   // Compute the inverse to get map from tensor index -> producer loop.
41   // The inverse is a map from producer result tensor index -> producer loop.
42   AffineMap invProducerResultIndexMap =
43       inversePermutation(producerResultIndexMap);
44   assert(invProducerResultIndexMap &&
45          "expected producer result indexig map to be invertible");
46 
47   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
48   // argMap is a map from producer loop -> producer arg tensor index.
49   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
50 
51   // Compose argMap with invProducerResultIndexMap to get a map from
52   // producer result tensor index -> producer arg tensor index.
53   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
54 
55   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
56   // consumer loop/ fused loop -> producer arg tensor index.
57   return t1.compose(fusedConsumerArgIndexMap);
58 }
59 
/// Conditions for elementwise fusion of generic operations.
///
/// Returns true if `producer` — whose result feeds `consumerOpOperand` of
/// `consumer` — can be fused with `consumer` into a single elementwise
/// generic op, false otherwise.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  // Verify that
  // - the producer has all "parallel" iterator type.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isInputTensor(consumerOpOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Currently support only operations with single result.
  if (producer.getNumOutputs() != 1)
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if ((consumer.getNumReductionLoops())) {
    llvm::BitVector coveredDims(consumer.getNumLoops(), false);

    // Marks as covered every fused-loop dimension that appears as a plain
    // AffineDimExpr result of `map`.
    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
          coveredDims[dimExpr.getPosition()] = true;
    };

    // All consumer operands other than the one being fused keep their
    // original indexing maps in the fused op, so they cover dims directly.
    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
      Value operand = std::get<0>(pair);
      if (operand == consumerOpOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    // The producer's inputs enter the fused op with remapped indexing maps
    // expressed in fused-loop (consumer-loop) coordinates.
    for (OpOperand *operand : producer.getInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
128 
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
///
/// `consumerToProducerLoopsMap` maps the fused (consumer) loop dimensions to
/// the producer's loop dimensions; `consumerOpOperand` is the consumer input
/// operand whose defining op is the producer being fused.
/// NOTE(review): `nloops` is not referenced in this body — confirm whether it
/// can be dropped from the signature.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  // The fused block is built argument-by-argument in the numbered steps below;
  // the step numbering matches `fuseElementwiseOpsImpl`, which builds the
  // operand list in the same order.
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    // Each producer linalg.index is replaced by an affine.apply of the fused
    // indices through the corresponding row of the loop map.
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation. Index ops were already handled in step 2.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks, if replacement is not already in the mapper then it must be
  // produced outside.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op. This includes the
  // consumer's terminator, which becomes the fused op's terminator.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
241 
/// Fuses `producer` into the consumer that owns `consumerOpOperand`, creating
/// a single generic op whose iteration space is the consumer's.
///
/// Returns the values that replace the consumer's results on success, or
/// llvm::None when the ops are not fusable or `controlFn` rejects the fusion.
/// The numbered steps below mirror `generateFusedElementwiseOpRegion`, which
/// adds block arguments in the same order as the operands collected here.
static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlElementwiseOpsFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
      !controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(producer->getNumOperands() +
                        consumer->getNumOperands());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of `generateFusedTensorOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator it =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  for (OpOperand *opOperand : producer.getInputOperands()) {
    fusedOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexig map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}
338 
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] :
///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  // The "source" side here is the larger-rank side of the reshape: the result
  // of an expand_shape, or the source of a collapse_shape.
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      // Unit dimensions contribute nothing to the linearized expression.
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    // Build the row-major linear combination, e.g. d1 * 20 + d2 * 5 + d3.
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  // The new affine map cannot drop unused dimension but some new symbols may
  // have been added. Create a map with at least as many dimensions/symbols as
  // the original affine map.
  int64_t maxDim = -1;
  int64_t maxSym = -1;
  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
  return AffineMap::get(numDims, numSyms, resultExprs, context);
}
397 
398 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
399 // producer). Fusing when operand has higher rank will require use of mods and
400 // divs in the indexing maps of the fused op which would make it non-invertible.
401 static bool isTensorReshapeOpFoldableByLinearization(
402     tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
403   if (!asProducer)
404     return false;
405   return useIndexMap.isPermutation();
406 }
407 
408 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
409 // consumer).
410 static bool
411 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
412                                          AffineMap useIndexMap,
413                                          bool asProducer) {
414   if (asProducer)
415     return false;
416   return useIndexMap.isPermutation();
417 }
418 
419 /// Check if the reshape operation is only expansion into/collapsing of
420 /// unit-dimension.
421 template <typename TensorReshapeOp>
422 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
423   constexpr bool isExpanding =
424       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
425   ArrayRef<int64_t> expandedShape =
426       (isExpanding ? reshapeOp.getResultType().getShape()
427                    : reshapeOp.getSrcType().getShape());
428   for (auto &indices : reshapeOp.getReassociationIndices()) {
429     unsigned numUnitDims = 0;
430     for (int64_t position : indices)
431       if (expandedShape[position] == 1)
432         numUnitDims++;
433     if (numUnitDims != indices.size() - 1)
434       return false;
435   }
436   return true;
437 }
438 
439 /// Conditions for folding a generic operation with a reshape op by expanding
440 /// the iteration space dimensionality for tensor operations. These are
441 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
442 /// following fusion pattern.
443 ///
444 ///  Consider
445 ///
446 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
447 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
448 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
449 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
450 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
451 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
452 ///
453 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
454 ///  is increased to match the result (operand) of the tensor_expand_shape.
455 ///  The indexing_map of the fused tensor in the `genericOp` and the
456 ///  reassociation map helps compute the indexing maps of the modified op.
457 ///  For the above example, based on the reassociation map it
458 ///  can be concluded that
459 ///
460 ///  - The loop used to access the first dimension of the fused tensor is split
461 ///    into two.
462 ///  - The loop used to access the second dimension of the fused tensor is kept
463 ///    as is.
464 ///  - The loop used to access the third dimension of the fused tensor is split
465 ///    into three.
466 ///
467 ///  i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
468 ///  op, then
469 ///
470 ///   d0 -> e0, e1
471 ///   d1 -> e2, e3, e4
472 ///   d2 -> e5
473 ///
474 ///  substituting this, the generic op can be rewritten as
475 ///
476 ///  %d = linalg.generic ins(%0, %1 : )
477 ///        indexing_maps =
478 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
479 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
480 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
481 ///
482 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
483 ///  to make it consistent
484 ///
485 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
486 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
487 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
488 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
489 ///
490 ///  The added reshapes are again expanding patterns, so they will get fused
491 ///  with its producers if possible.
492 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
493                                                OpOperand *fusableOpOperand) {
494   // Is fusable only if:
495   // - All the indexing maps for operands and results are projected
496   //   permutations.
497   // - The fused tensor is not a scalar.
498   // - All the loops are parallel loops.
499   return genericOp.hasTensorSemantics() &&
500          llvm::all_of(genericOp.indexing_maps().getValue(),
501                       [](Attribute attr) {
502                         return attr.cast<AffineMapAttr>()
503                             .getValue()
504                             .isProjectedPermutation();
505                       }) &&
506          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
507          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
508            return attr.cast<StringAttr>().getValue() ==
509                   getParallelIteratorTypeName();
510          });
511 }
512 
namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassocationMaps` of the reshape op and the shape of
  // the expanded op. Returns failure when the expansion cannot be computed
  // (e.g. no reassociation maps or unknown loop ranges).
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        PatternRewriter &rewriter);
  // Number of loop dimensions of the original (unexpanded) op.
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  // Number of loop dimensions of the expanded op.
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  // Expanded-op dimensions that original dimension `i` expands into.
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  // Extents of the expanded dimensions corresponding to original dimension `i`.
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  // Total number of dimensions in the expanded operation.
  unsigned expandedOpNumDims;
};
} // namespace
545 
/// Populates `reassociation`, `expandedShapeMap` and `expandedOpNumDims` from
/// the indexing map of `fusableOpOperand`, the reshape's `reassociationMaps`
/// and the `expandedShape` of the reshaped tensor. Fails when there are no
/// reassociation maps or the static loop ranges of `linalgOp` are unknown.
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  // Static loop ranges are needed to size the dimensions that are untouched
  // by the expansion.
  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimension in the expanded op that correspond to each
  // dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    // The expanded extents of loop `pos` are a contiguous slice of the
    // expanded tensor shape starting at the group's first dimension.
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
    // Original dim `index` expands into consecutive dims [sum, sum + count).
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
590 
591 /// Epanding the body of a linalg operation requires adaptations of the accessed
592 /// loop indices. Specifically, access of indices in the original operation need
593 /// to be replaced with linearizations of indices in the expanded op. That
594 /// requires the shape of the expanded dimensions to be static (at least all but
595 /// the most significant). For now check that these are all statically sized.
596 /// Note that this could be extended to handle dynamic case, but the
597 /// implementation below uses `affine.apply` which seems to have issues when the
598 /// shapes are not static.
599 LogicalResult isGenericOpExpandable(GenericOp genericOp,
600                                     const ExpansionInfo &expansionInfo,
601                                     PatternRewriter &rewriter) {
602   if (!genericOp.hasIndexSemantics())
603     return success();
604   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
605     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
606     if (expandedShape.size() == 1)
607       continue;
608     for (int64_t shape : expandedShape.drop_front()) {
609       if (ShapedType::isDynamic(shape)) {
610         return rewriter.notifyMatchFailure(
611             genericOp, "cannot expand due to index semantics and dynamic dims");
612       }
613     }
614   }
615   return success();
616 }
617 
618 /// Return the indexing map to use in the expanded op for a given the
619 /// `indexingMap` of the original operation.
620 static AffineMap
621 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
622                            const ExpansionInfo &expansionInfo) {
623   SmallVector<AffineExpr> newExprs;
624   for (AffineExpr expr : indexingMap.getResults()) {
625     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
626     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
627         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
628           return builder.getAffineDimExpr(static_cast<unsigned>(v));
629         }));
630     newExprs.append(expandedExprs.begin(), expandedExprs.end());
631   }
632   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
633                         indexingMap.getNumSymbols(), newExprs,
634                         builder.getContext());
635 }
636 
637 /// Return the type of the operand/result to use in the expanded op given the
638 /// type in the original op.
639 static RankedTensorType getExpandedType(RankedTensorType originalType,
640                                         AffineMap indexingMap,
641                                         const ExpansionInfo &expansionInfo) {
642   SmallVector<int64_t> expandedShape;
643   for (AffineExpr expr : indexingMap.getResults()) {
644     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
645     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
646     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
647   }
648   return RankedTensorType::get(expandedShape, originalType.getElementType());
649 }
650 
651 /// Returns the reassociation maps to use in the `tensor.expand_shape`
652 /// operation to convert the operands of the original operation to operands of
653 /// the expanded operation. The same method is used to compute the
654 /// `tensor.collapse_shape` used to collapse the result of the expanded
655 /// op to get the value that can replace all uses of the results of the original
656 /// op.
657 static SmallVector<ReassociationIndices>
658 getReassociationForExpansion(AffineMap indexingMap,
659                              const ExpansionInfo &expansionInfo) {
660   SmallVector<ReassociationIndices> reassociation;
661   unsigned numReshapeDims = 0;
662   for (AffineExpr expr : indexingMap.getResults()) {
663     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
664     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
665     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
666         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
667     reassociation.emplace_back(std::move(indices));
668     numReshapeDims += numExpandedDims;
669   }
670   return reassociation;
671 }
672 
/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  // Use early-increment iteration since each visited IndexOp is replaced (and
  // thus erased) within the loop body.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    // Dimensions of the expanded op that the original dimension maps to.
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    // Sizes of all but the outermost expanded dimension; these are the
    // multipliers of the linearization (must be static, see the assert below).
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    // Materialize an IndexOp for every expanded dimension except the first.
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    // Horner-style accumulation: newIndex = (..(i0 * s1 + i1) * s2 + i2 ..).
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
715 
716 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
717 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
718 /// that those conditions have been satisfied.
719 static Optional<SmallVector<Value>>
720 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
721                            OpOperand *fusableOpOperand,
722                            PatternRewriter &rewriter) {
723   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
724          "preconditions for fuse operation failed");
725   // Check if reshape is expanding or collapsing.
726   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
727   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
728   bool isExpanding = (expandingReshapeOp != nullptr);
729   RankedTensorType expandedType = isExpanding
730                                       ? expandingReshapeOp.getResultType()
731                                       : collapsingReshapeOp.getSrcType();
732 
733   ExpansionInfo expansionInfo;
734   if (failed(expansionInfo.compute(
735           genericOp, fusableOpOperand,
736           isExpanding ? expandingReshapeOp.getReassociationMaps()
737                       : collapsingReshapeOp.getReassociationMaps(),
738           expandedType.getShape(), rewriter)))
739     return llvm::None;
740 
741   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
742     return llvm::None;
743 
744   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
745       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
746         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
747       }));
748 
749   SmallVector<Value> expandedOpOperands;
750   expandedOpOperands.reserve(genericOp.getNumInputs());
751   for (OpOperand *opOperand : genericOp.getInputOperands()) {
752     if (opOperand == fusableOpOperand) {
753       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
754                                                : collapsingReshapeOp.src());
755       continue;
756     }
757     if (genericOp.isInputTensor(opOperand)) {
758       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
759       RankedTensorType expandedOperandType =
760           getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
761                           indexingMap, expansionInfo);
762       if (expandedOperandType != opOperand->get().getType()) {
763         // Reshape the operand to get the right type.
764         SmallVector<ReassociationIndices> reassociation =
765             getReassociationForExpansion(indexingMap, expansionInfo);
766         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
767             genericOp.getLoc(), expandedOperandType, opOperand->get(),
768             reassociation));
769         continue;
770       }
771     }
772     expandedOpOperands.push_back(opOperand->get());
773   }
774 
775   Location loc = genericOp.getLoc();
776   SmallVector<Value> outputs;
777   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
778     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
779     RankedTensorType expandedOutputType =
780         getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
781                         indexingMap, expansionInfo);
782     if (expandedOutputType != opOperand->get().getType()) {
783       SmallVector<ReassociationIndices> reassociation =
784           getReassociationForExpansion(indexingMap, expansionInfo);
785       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
786           genericOp.getLoc(), expandedOutputType, opOperand->get(),
787           reassociation));
788     }
789   }
790 
791   // The iterator types of the expanded op are all parallel.
792   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
793                                        getParallelIteratorTypeName());
794 
795   TypeRange resultTypes = ValueRange(outputs).getTypes();
796   auto fusedOp =
797       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
798                                  /*inputs=*/expandedOpOperands, outputs,
799                                  expandedOpIndexingMaps, iteratorTypes);
800   Region &fusedRegion = fusedOp->getRegion(0);
801   Region &originalRegion = genericOp->getRegion(0);
802   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
803 
804   // Update the index accesses after the expansion.
805   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
806 
807   // Reshape the result values to their original shape if this is a collapsing
808   // reshape folded into its consumer.
809   SmallVector<Value> resultVals;
810   for (OpResult opResult : genericOp->getOpResults()) {
811     int64_t resultNumber = opResult.getResultNumber();
812     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
813       SmallVector<ReassociationIndices> reassociation =
814           getReassociationForExpansion(
815               genericOp.getTiedIndexingMap(
816                   genericOp.getOutputOperand(resultNumber)),
817               expansionInfo);
818       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
819           genericOp.getLoc(), opResult.getType(),
820           fusedOp->getResult(resultNumber), reassociation));
821     } else {
822       resultVals.push_back(fusedOp->getResult(resultNumber));
823     }
824   }
825   // Assuming a single result.
826   return resultVals;
827 }
828 
829 namespace {
830 
/// Pattern to fold tensor_expand_shape op with its consumer by using the source
/// of the reshape op as the operand in the consumer (instead of the result of
/// the reshape op). The corresponding index map in the consumer needs to be
/// modified to linearize the folded dimension.
835 ///
836 /// For example,
837 ///
838 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
839 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
840 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
841 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
842 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
843 ///        -> tensor<?x?x4x?xf32>
844 ///
845 /// can be folded into
846 ///
847 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
848 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
849 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
850 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
851 ///        -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only handle ops with tensor semantics.
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    // Fold the first input operand that is produced by a foldable reshape.
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      // The reshape must be linearizable w.r.t. the operand's indexing map;
      // the template parameter optionally restricts folding to reshapes that
      // only expand/collapse unit dimensions.
      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list: same operands, with the reshape
      // result replaced by the reshape source.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumers that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      // Every result expression must stay pure affine (linearization of
      // dynamic sizes can introduce semi-affine expressions).
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Rewrite the generic op in place: swap in the new operands and
      // indexing maps under a root-update transaction.
      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};
914 
915 static SmallVector<ReassociationIndices>
916 getReassociationIndices(ArrayRef<AffineMap> maps) {
917   SmallVector<ReassociationIndices> reassociation;
918   for (AffineMap map : maps) {
919     ReassociationIndices indices;
920     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
921       unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
922       indices.push_back(pos);
923     }
924     reassociation.push_back(indices);
925   }
926   return reassociation;
927 }
928 
/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
/// within the dimensions we need to merge.
933 ///
934 /// For example,
935 ///
936 ///  %0 = tensor.expand_shape %A [[0, 1], [2]]
937 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
938 ///  %2 = linalg.generic {indexing_maps = [
939 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
940 ///    affine_map<(d0, d1, d2) -> (d2)>,
941 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
942 ///    ["parallel", "parallel", "parallel"]} {
943 ///  } -> tensor<112x112x16xf32>
944 ///
945 ///  into
946 ///
947 ///  %2 = linalg.generic {indexing_maps = [
948 ///    affine_map<(d0, d1) -> (d0, d1)>,
949 ///    affine_map<(d0, d1) -> (d1)>,
950 ///    affine_map<(d0, d1) -> (d0, d1)>],
951 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
952 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
953 ///  } -> tensor<12544x16xf32>
954 ///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
955 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensor.
    if (!genericOp.hasTensorSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and record the dimensions
    // merged. The first expand_shape found fixes the reassociation; further
    // expand_shape operands are only folded if they reassociate identically.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociate maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the inverse mapping
    // (expanded dimension -> index of the collapsed group it belongs to).
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg and that we
    // don't need to create new reshapes operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    // Operands that were not rewritten in step 1 must not reference any
    // dimension belonging to a multi-dim group, or they would need a reshape.
    for (auto en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip dimension merged except for the last of the group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors to the collapsed (source) shape.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Expand the results back so that the types match the original uses.
    SmallVector<Value> newResults;
    for (auto result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};
1070 
1071 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
1072 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
1073 /// in the consumer is expanded.
1074 class FoldWithProducerReshapeOpByExpansion
1075     : public OpRewritePattern<GenericOp> {
1076 public:
1077   FoldWithProducerReshapeOpByExpansion(
1078       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1079       PatternBenefit benefit = 1)
1080       : OpRewritePattern<GenericOp>(context, benefit),
1081         controlFoldingReshapes(foldReshapes) {}
1082 
1083   LogicalResult matchAndRewrite(GenericOp genericOp,
1084                                 PatternRewriter &rewriter) const override {
1085     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1086       tensor::CollapseShapeOp reshapeOp =
1087           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1088       if (!reshapeOp)
1089         continue;
1090       // Fold only if
1091       // - The tensor reshape op is folding.
1092       // - All constraints of fusing with reshape by expansion are met.
1093       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1094           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1095         continue;
1096 
1097       Optional<SmallVector<Value>> replacementValues =
1098           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1099       if (!replacementValues)
1100         return failure();
1101       rewriter.replaceOp(genericOp, replacementValues.getValue());
1102       return success();
1103     }
1104     return failure();
1105   }
1106 
1107 private:
1108   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1109 };
1110 
1111 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
1112 /// producer. The corresponding index map in the consumer needs to be modified
1113 /// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    // Match only single-output tensor generic producers whose output indexing
    // map makes the reshape foldable by linearization; the template parameter
    // optionally restricts folding to unit-dimension reshapes.
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer =*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are same as
    // those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    // Accepted output maps are identity or permutation, so this inverse
    // exists.
    auto invMap = inversePermutation(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
    // Linearization must not have introduced semi-affine expressions.
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    // Reshape the producer's init/output tensor to the reshape's result shape.
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};
1174 
1175 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1176 /// by expanding the dimensionality of the loop in the producer op.
1177 struct FoldReshapeWithGenericOpByExpansion
1178     : public OpRewritePattern<tensor::ExpandShapeOp> {
1179 
1180   FoldReshapeWithGenericOpByExpansion(
1181       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1182       PatternBenefit benefit = 1)
1183       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1184         controlFoldingReshapes(foldReshapes) {}
1185 
1186   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1187                                 PatternRewriter &rewriter) const override {
1188     // Fold only if all constraints of fusing with reshape by expansion are met.
1189     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1190     if (!producer || producer.getNumOutputs() != 1 ||
1191         !isFusableWithReshapeByDimExpansion(producer,
1192                                             producer.getOutputOperand(0)) ||
1193         !controlFoldingReshapes(producer->getResult(0),
1194                                 reshapeOp->getOpOperand(0)))
1195       return failure();
1196     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1197         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1198     if (!replacementValues)
1199       return failure();
1200     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1201     return success();
1202   }
1203 
1204 private:
1205   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1206 };
1207 
1208 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1209 /// handle cases where the constant is not single-valued.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      // Matches a splat DenseElementsAttr (with int/float element type) or a
      // scalar integer/float constant; on success `constantAttr` holds the
      // single scalar value.
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the
      // same as the operands and indexing_maps of the generic operation, with
      // the values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant, so the cloned region reads the scalar directly.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};
1315 
1316 /// Base class for constant folding linalg.generic ops with N inputs, 1 output,
1317 /// and permutation indexing maps.
1318 ///
1319 /// `ConcreteType` should provide methods with signatures
1320 ///
1321 /// ```c++
1322 ///   bool matchIndexingMaps(GenericOp genericOp) const;
1323 ///   RegionComputationFn getRegionComputeFn(GenericOp) const;
1324 /// ```
1325 ///
1326 /// The latter inspects the region and returns the computation inside as a
1327 /// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  // Holds one scalar element; exactly one field is set, depending on whether
  // the folded op operates on integer or floating-point elements.
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  // Holds the constant input elements for one loop iteration; only one of the
  // two vectors is populated, matching the element type being folded.
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  // Computes one constant output element from the given input elements.
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Folding replaces the op with a constant tensor, so this only applies to
    // ops with tensor semantics.
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given we are generating constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    // All inputs must be shaped values so their constants can be enumerated.
    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, and then do affine apply on them, and then match back the
    // constant again.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    // The payload must not read the output operands, since those values are
    // discarded by the fold.
    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases but they have the same lifetime as the
    // MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate spaces for compute function inputs. Initial values do not matter
    // here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      // Delinearize into per-loop-dimension indices (row-major order).
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      // Permute the loop indices into each operand's coordinate space using
      // the (permutation) indexing maps.
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      // Re-linearize the per-operand indices against each operand's shape.
      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};
1555 
1556 // Folds linalg.generic ops that are actually transposes on constant values.
1557 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
1558   using FoldConstantBase::FoldConstantBase;
1559 
1560   bool matchIndexingMaps(GenericOp genericOp) const {
1561     // We should have one input and one output.
1562     return genericOp.getIndexingMaps().size() == 2;
1563   }
1564 
1565   RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
1566     // Make sure the region only contains a yield op.
1567     Block &body = genericOp.region().front();
1568     if (!llvm::hasSingleElement(body))
1569       return nullptr;
1570     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1571     if (!yieldOp)
1572       return nullptr;
1573 
1574     // The yield op should return the block argument corresponds to the input.
1575     for (Value yieldVal : yieldOp.values()) {
1576       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
1577       if (!yieldArg || yieldArg.getOwner() != &body)
1578         return nullptr;
1579       if (yieldArg.getArgNumber() != 0)
1580         return nullptr;
1581     }
1582 
1583     // No computation; just return the orginal value.
1584     return [](const APIntOrFloatArray &inputs) {
1585       if (inputs.apFloats.empty())
1586         return APIntOrFloat{inputs.apInts.front(), llvm::None};
1587       return APIntOrFloat{llvm::None, inputs.apFloats.front()};
1588     };
1589   }
1590 
1591   ControlElementwiseOpsFusionFn controlFn;
1592 };
1593 
1594 } // namespace
1595 
1596 static Optional<SmallVector<Value>>
1597 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
1598                    GenericOp producer,
1599                    const ControlElementwiseOpsFusionFn &controlFn) {
1600   if (producer->getNumResults() != 1)
1601     return llvm::None;
1602 
1603   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
1604                                 rewriter);
1605 }
1606 
1607 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
1608                                       OpOperand &consumer) {
1609   if (auto producerCollapseOp =
1610           dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
1611     return !isUnitDimExpansionOnly(producerCollapseOp);
1612   }
1613   if (auto consumerExpandOp =
1614           dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
1615     return !isUnitDimExpansionOnly(consumerExpandOp);
1616   }
1617   return true;
1618 }
1619 
1620 namespace {
1621 /// Patterns to fuse a generic op, with the producer of its operands.
1622 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
1623 public:
1624   FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
1625                      PatternBenefit benefit = 1)
1626       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1627 
1628   LogicalResult matchAndRewrite(GenericOp genericOp,
1629                                 PatternRewriter &rewriter) const override {
1630     // Find the first operand that is defined by another generic op on tensors.
1631     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
1632       auto producer =
1633           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
1634       if (!producer || !producer.hasTensorSemantics())
1635         continue;
1636       Optional<SmallVector<Value>> fusedOpResults =
1637           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
1638       if (fusedOpResults) {
1639         rewriter.replaceOp(genericOp, *fusedOpResults);
1640         return success();
1641       }
1642     }
1643     return failure();
1644   }
1645 
1646 private:
1647   ControlElementwiseOpsFusionFn controlFn;
1648 };
1649 
1650 /// Pass that fuses generic ops on tensors. Used only for testing.
1651 struct LinalgElementwiseOpFusionPass
1652     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1653   void runOnOperation() override {
1654     Operation *op = getOperation();
1655     RewritePatternSet patterns(op->getContext());
1656     ControlElementwiseOpsFusionFn allowFoldingFn =
1657         [](const OpResult &producer, const OpOperand &consumer) {
1658           return true;
1659         };
1660     populateElementwiseOpsFusionPatterns(
1661         patterns,
1662         LinalgElementwiseFusionOptions().setControlFoldingReshapes(
1663             allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
1664 
1665     // Use TopDownTraversal for compile time reasons
1666     GreedyRewriteConfig grc;
1667     grc.useTopDownTraversal = true;
1668     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1669                                        grc);
1670   }
1671 };
1672 
1673 /// Pass to test folding of reshape ops with generic ops by linearization.
1674 struct FoldReshapeOpsByLinearizationPass
1675     : public LinalgFoldReshapeOpsByLinearizationBase<
1676           FoldReshapeOpsByLinearizationPass> {
1677   void runOnOperation() override {
1678     Operation *op = getOperation();
1679     RewritePatternSet patterns(op->getContext());
1680     populateFoldReshapeOpsByLinearizationPatterns(patterns);
1681     if (allowFoldingUnitDimReshapes) {
1682       populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
1683     }
1684     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1685   }
1686 };
1687 
1688 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
1689 /// the value of the `outs` operand is not used within the op.  This is only
1690 /// implemented for `linalg.generic` operations for now, but should hold for all
1691 /// linalg structured ops.
1692 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1693   using OpRewritePattern<GenericOp>::OpRewritePattern;
1694 
1695   LogicalResult matchAndRewrite(GenericOp op,
1696                                 PatternRewriter &rewriter) const override {
1697     rewriter.startRootUpdate(op);
1698     bool modifiedOutput = false;
1699     Location loc = op.getLoc();
1700     for (OpOperand *opOperand : op.getOutputOperands()) {
1701       if (!op.payloadUsesValueFromOperand(opOperand)) {
1702         Value operandVal = opOperand->get();
1703         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1704         if (!operandType)
1705           continue;
1706 
1707         // If outs is already an `init_tensor` operation, nothing to do.
1708         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1709         if (definingOp)
1710           continue;
1711         modifiedOutput = true;
1712         SmallVector<Value> dynamicDims;
1713         for (auto dim : llvm::enumerate(operandType.getShape())) {
1714           if (dim.value() != ShapedType::kDynamicSize)
1715             continue;
1716           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1717               loc, operandVal, dim.index()));
1718         }
1719         Value initTensor = rewriter.create<InitTensorOp>(
1720             loc, dynamicDims, operandType.getShape(),
1721             operandType.getElementType());
1722         op->setOperand(opOperand->getOperandNumber(), initTensor);
1723       }
1724     }
1725     if (!modifiedOutput) {
1726       rewriter.cancelRootUpdate(op);
1727       return failure();
1728     }
1729     rewriter.finalizeRootUpdate(op);
1730     return success();
1731   }
1732 };
1733 
1734 } // namespace
1735 
1736 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
1737     RewritePatternSet &patterns) {
1738   patterns
1739       .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1740            FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
1741            FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1742            FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
1743           patterns.getContext());
1744 }
1745 
1746 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
1747     RewritePatternSet &patterns) {
1748   patterns
1749       .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1750            FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
1751            FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1752            FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
1753           patterns.getContext());
1754 }
1755 
1756 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1757     RewritePatternSet &patterns,
1758     ControlElementwiseOpsFusionFn controlFoldingReshapes) {
1759   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1760                                                     controlFoldingReshapes);
1761   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1762                                                      controlFoldingReshapes);
1763 }
1764 
1765 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1766     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
1767   auto *context = patterns.getContext();
1768   patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
1769                FoldConstantTranspose>(context,
1770                                       options.controlElementwiseOpsFusionFn);
1771   patterns.add<RemoveOutsDependency>(context);
1772   populateFoldReshapeOpsByExpansionPatterns(patterns,
1773                                             options.controlFoldingReshapesFn);
1774   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1775   GenericOp::getCanonicalizationPatterns(patterns, context);
1776   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1777   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1778   context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1779       patterns);
1780 }
1781 
1782 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
1783   auto *context = patterns.getContext();
1784   patterns.add<PushExpandingReshape>(context);
1785 }
1786 
1787 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1788   return std::make_unique<LinalgElementwiseOpFusionPass>();
1789 }
1790 
1791 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1792   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1793 }
1794