1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 //===---------------------------------------------------------------------===//
31 // Methods and patterns that fuse elementwise `linalg.generic` operations.
32 //===---------------------------------------------------------------------===//
33 
34 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
35 /// the `producer` to use in the fused operation given the indexing map of the
36 /// result of the producer in the consumer.
37 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
38     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
39     AffineMap fusedConsumerArgIndexMap) {
40   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
41   // from consumer loop -> consumer arg tensor index/producer result tensor
42   // index. The fused loop is same as the consumer loop. For each producer arg
43   // the indexing map to be computed is a map from consumer loop -> producer
44   // arg tensor index.
45   // producerResultIndexMap is a map from producer loop -> tensor index.
46   // Compute the inverse to get map from tensor index -> producer loop.
47   // The inverse is a map from producer result tensor index -> producer loop.
48   AffineMap invProducerResultIndexMap =
49       inversePermutation(producerResultIndexMap);
50   assert(invProducerResultIndexMap &&
51          "expected producer result indexig map to be invertible");
52 
53   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
54   // argMap is a map from producer loop -> producer arg tensor index.
55   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
56 
57   // Compose argMap with invProducerResultIndexMap to get a map from
58   // producer result tensor index -> producer arg tensor index.
59   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
60 
61   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
62   // consumer loop/ fused loop -> producer arg tensor index.
63   return t1.compose(fusedConsumerArgIndexMap);
64 }
65 
66 /// Conditions for elementwise fusion of generic operations.
67 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
68                                      OpOperand *consumerOpOperand) {
69   // Producer and consumer must have tensor semantics.
70   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
71     return false;
72 
73   // Verify that
74   // - the producer has all "parallel" iterator type.
75   if (producer.getNumParallelLoops() != producer.getNumLoops())
76     return false;
77 
78   // Only allow fusing the producer of an input operand for now.
79   // TODO: allow fusing the producer of an output operand.
80   if (!consumer.isInputTensor(consumerOpOperand))
81     return false;
82 
83   // Get the consumer index map. The number of results of the consumer index
84   // map must match the number of loops of the producer.
85   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
86   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
87     return false;
88 
89   // Currently support only operations with single result.
90   if (producer.getNumOutputs() != 1)
91     return false;
92 
93   // Finally the index_map for the result must be invertible. For now just
94   // verify it is a permutation.
95   AffineMap producerResultIndexMap =
96       producer.getTiedIndexingMap(producer.getOutputOperand(0));
97   if (!producerResultIndexMap.isPermutation())
98     return false;
99 
100   // Ensure that the fusion does not remove size information required to
101   // get the loop bounds. For non-reduction generics, this is trivially the
102   // case due to the output operand. For reductions, we need to check that after
103   // the fusion, each loop dimension has at least one input that defines it.
104   if ((consumer.getNumReductionLoops())) {
105     BitVector coveredDims(consumer.getNumLoops(), false);
106 
107     auto addToCoveredDims = [&](AffineMap map) {
108       for (auto result : map.getResults())
109         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
110           coveredDims[dimExpr.getPosition()] = true;
111     };
112 
113     for (auto pair :
114          llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
115       Value operand = std::get<0>(pair);
116       if (operand == consumerOpOperand->get())
117         continue;
118       AffineMap operandMap = std::get<1>(pair);
119       addToCoveredDims(operandMap);
120     }
121 
122     for (OpOperand *operand : producer.getInputOperands()) {
123       AffineMap newIndexingMap =
124           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
125               operand, producerResultIndexMap, consumerIndexMap);
126       addToCoveredDims(newIndexingMap);
127     }
128     if (!coveredDims.all())
129       return false;
130   }
131 
132   return true;
133 }
134 
135 /// Generate the region of the fused tensor operation. The region of the fused
136 /// op must be empty.
137 static void
138 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
139                                  AffineMap consumerToProducerLoopsMap,
140                                  OpOperand *consumerOpOperand,
141                                  unsigned nloops) {
142   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
143   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
144   // Build the region of the fused op.
145   Block &producerBlock = producer->getRegion(0).front();
146   Block &consumerBlock = consumer->getRegion(0).front();
147   Block *fusedBlock = new Block();
148   fusedOp.region().push_back(fusedBlock);
149   BlockAndValueMapping mapper;
150   OpBuilder::InsertionGuard guard(rewriter);
151   rewriter.setInsertionPointToStart(fusedBlock);
152 
153   // 2. Add an index operation for every fused loop dimension and use the
154   // `consumerToProducerLoopsMap` to map the producer indices.
155   if (producer.hasIndexSemantics()) {
156     // Add an index operation for every fused loop dimension.
157     unsigned numFusedOpLoops =
158         std::max(producer.getNumLoops(), consumer.getNumLoops());
159     SmallVector<Value> fusedIndices;
160     fusedIndices.reserve(numFusedOpLoops);
161     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
162                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
163                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
164                     });
165     for (IndexOp indexOp :
166          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
167       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
168           producer.getLoc(),
169           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
170       mapper.map(indexOp.getResult(), newIndex);
171     }
172   }
173   // TODO: allow fusing the producer of an output operand.
174   assert(consumer.isInputTensor(consumerOpOperand) &&
175          "expected producer of input operand");
176   // 3. Consumer input operands up to consumerIdx (exclusive).
177   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
178            consumerOpOperand->getOperandNumber())) // input assumption.
179     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
180 
181   // Replacing consumerIdx requires getting the cloned, yielded, value from
182   // the (cloned) producer block. This happens in step 9.
183 
184   // 4. Splice in producer's input operands.
185   for (BlockArgument bbArg :
186        producerBlock.getArguments().take_front(producer.getNumInputs()))
187     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
188 
189   // 4.b. Producer output operand/map that is fused needs to be mapped to the
190   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
191   assert(producer->getNumResults() == 1 && "expected single result producer");
192   if (producer.isInitTensor(producer.getOutputOperand(0))) {
193     BlockArgument bbArg = producerBlock.getArguments()
194                               .drop_front(producer.getNumInputs())
195                               // TODO: bbArg index of
196                               .front();
197     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
198   }
199   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
200   for (BlockArgument bbArg :
201        consumerBlock.getArguments()
202            .take_front(consumer.getNumInputs())
203            .drop_front(consumerOpOperand->getOperandNumber() + 1))
204     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
205   // 6. All of consumer's output operands.
206   for (BlockArgument bbArg :
207        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
208     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
209   // 7. All of producer's output operands except the one fused.
210   // TODO: allow fusion of multi-result producers.
211   assert(producer->getNumResults() == 1 && "expected single result producer");
212 
213   // 8. Clone all producer operations except for the yield and index operations
214   // to the fused operation.
215   for (auto &op : producerBlock.without_terminator()) {
216     if (!isa<IndexOp>(op))
217       rewriter.clone(op, mapper);
218   }
219   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
220   // forward the yield operand.
221   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
222   // TODO: allow fusion of multi-result producers.
223   assert(producer->getNumResults() == 1 && "expected single result producer");
224   unsigned producerResultNumber = 0;
225   Value replacement =
226       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
227   // Sanity checks, if replacement is not already in the mapper then it must be
228   // produced outside.
229   if (replacement == yieldOp.getOperand(producerResultNumber)) {
230     if (auto bb = replacement.dyn_cast<BlockArgument>())
231       assert(bb.getOwner() != &producerBlock &&
232              "yielded block argument must have been mapped");
233     else
234       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
235              "yielded value must have been mapped");
236   }
237   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
238              replacement);
239   // 10. Clone operations from the consumer to the fused op.
240   for (auto &op : consumerBlock.getOperations())
241     rewriter.clone(op, mapper);
242 
243   // Sanity checks.
244   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
245          "Ill-formed GenericOp region");
246 }
247 
248 static Optional<SmallVector<Value>>
249 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
250                        const ControlElementwiseOpsFusionFn &controlFn,
251                        PatternRewriter &rewriter) {
252   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
253   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
254       !controlFn(producer->getResult(0), *consumerOpOperand))
255     return llvm::None;
256 
257   // TODO: allow fusing the producer of an output operand.
258   assert(consumer.isInputTensor(consumerOpOperand) &&
259          "expected producer of input operand");
260 
261   // Compute the fused operands list and indexing maps.
262   SmallVector<Value> fusedOperands;
263   SmallVector<AffineMap> fusedIndexMaps;
264   fusedOperands.reserve(producer->getNumOperands() +
265                         consumer->getNumOperands());
266   fusedIndexMaps.reserve(producer->getNumOperands() +
267                          consumer->getNumOperands());
268   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
269   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
270   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
271   SmallVector<OpOperand *>::iterator it =
272       llvm::find(consumerInputs, consumerOpOperand);
273   assert(it != consumerInputs.end() && "expected to find the consumer operand");
274   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
275     fusedOperands.push_back(opOperand->get());
276     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
277   }
278   // 4. Splice in producer's input operands/maps.
279   assert(producer->getNumResults() == 1 && "expected single result producer");
280   AffineMap producerResultIndexMap =
281       producer.getTiedIndexingMap(producer.getOutputOperand(0));
282   for (OpOperand *opOperand : producer.getInputOperands()) {
283     fusedOperands.push_back(opOperand->get());
284     // Compute indexing maps for the producer args in the fused operation.
285     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
286         opOperand, producerResultIndexMap,
287         consumer.getTiedIndexingMap(consumerOpOperand));
288     fusedIndexMaps.push_back(map);
289   }
290   // 4.b. Producer output operand/map that is fused needs to be passed if it is
291   // an "initTensor" (i.e. its value is actually read).
292   assert(producer->getNumResults() == 1 && "expected single result producer");
293   if (producer.isInitTensor(producer.getOutputOperand(0))) {
294     fusedOperands.push_back(producer.getOutputOperand(0)->get());
295     // Compute indexing maps for the producer args in the fused operation.
296     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
297         producer.getOutputOperand(0), producerResultIndexMap,
298         consumer.getTiedIndexingMap(consumerOpOperand));
299     fusedIndexMaps.push_back(map);
300   }
301   // 5. Remaining consumer's input operands/maps (drop past index
302   // `consumerIdx`).
303   for (OpOperand *opOperand :
304        llvm::make_range(std::next(it), consumerInputs.end())) {
305     fusedOperands.push_back(opOperand->get());
306     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
307   }
308   // 6. All of consumer's output operands (skip operands: added by the builder).
309   for (OpOperand *opOperand : consumer.getOutputOperands())
310     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
311   // 7. All of producer's output operands/maps except the one fused.
312   // TODO: allow fusion of multi-result producers.
313   assert(producer->getNumResults() == 1 && "expected single result producer");
314 
315   // Generate the fused op.
316   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
317   auto fusedOp = rewriter.create<GenericOp>(
318       consumer.getLoc(), consumer->getResultTypes(),
319       /*inputs=*/fusedOperands,
320       // TODO: handle outputs.
321       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
322       consumer.iterator_types(),
323       /*doc=*/nullptr,
324       /*library_call=*/nullptr);
325   if (!fusedOp.getShapesToLoopsMap()) {
326     // Fused op has invalid indexing maps. Typically this means something is off
327     // in the input, but going ahead here would result in verification errors.
328     // So cleanup and abort.
329     rewriter.eraseOp(fusedOp);
330     return llvm::None;
331   }
332 
333   // Construct an AffineMap from consumer loops to producer loops.
334   // consumer loop -> tensor index
335   AffineMap consumerResultIndexMap =
336       consumer.getTiedIndexingMap(consumerOpOperand);
337   // tensor index -> producer loop
338   AffineMap invProducerResultIndexMap =
339       inversePermutation(producerResultIndexMap);
340   assert(invProducerResultIndexMap &&
341          "expected producer result indexig map to be invertible");
342   // consumer loop -> producer loop
343   AffineMap consumerToProducerLoopsMap =
344       invProducerResultIndexMap.compose(consumerResultIndexMap);
345 
346   generateFusedElementwiseOpRegion(rewriter, fusedOp,
347                                    consumerToProducerLoopsMap,
348                                    consumerOpOperand, consumer.getNumLoops());
349   return SmallVector<Value>(fusedOp->getResults());
350 }
351 
352 static Optional<SmallVector<Value>>
353 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
354                    GenericOp producer,
355                    const ControlElementwiseOpsFusionFn &controlFn) {
356   if (producer->getNumResults() != 1)
357     return llvm::None;
358 
359   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360                                 rewriter);
361 }
362 
363 namespace {
364 /// Patterns to fuse a generic op, with the producer of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
367   FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
368                      PatternBenefit benefit = 1)
369       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
370 
371   LogicalResult matchAndRewrite(GenericOp genericOp,
372                                 PatternRewriter &rewriter) const override {
373     // Find the first operand that is defined by another generic op on tensors.
374     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
375       auto producer =
376           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
377       if (!producer || !producer.hasTensorSemantics())
378         continue;
379       Optional<SmallVector<Value>> fusedOpResults =
380           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
381       if (fusedOpResults) {
382         rewriter.replaceOp(genericOp, *fusedOpResults);
383         return success();
384       }
385     }
386     return failure();
387   }
388 
389 private:
390   ControlElementwiseOpsFusionFn controlFn;
391 };
392 } // namespace
393 
394 //===---------------------------------------------------------------------===//
395 // Methods and patterns that fuse reshape ops with elementwise operations by
396 // linearization of indexing maps.
397 //===---------------------------------------------------------------------===//
398 
399 // TODO(ravishankarm): The indexing maps
400 // these produce in the general case are detrimental to transformations.
401 // These patterns are on deprecation path in favor of using fusion by
402 // collapsing, which covers the only legitimate use case of this pattern of
403 // folding unit-extent dims.
404 
405 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
406 /// provided, given the shape of the source tensor that corresponds to the
407 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
408 /// are "row-major" ordered logically.
409 ///
410 /// For example:
411 ///
412 /// %0 = op ... : tensor<?x?x4x5xf32>
413 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
414 ///
415 /// and reshape:
416 /// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] :
417 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
418 ///
419 /// would be rewritten into:
420 /// %0 = op ... : tensor<?x?x4x5xf32>
421 /// with output index_map
422 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
423 template <typename TensorReshapeOp>
424 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
425                                         TensorReshapeOp reshapeOp) {
426   constexpr bool isExpanding =
427       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
428   ArrayRef<int64_t> sourceShape =
429       (isExpanding ? reshapeOp.getResultType().getShape()
430                    : reshapeOp.getSrcType().getShape());
431   SmallVector<AffineExpr> resultExprs;
432   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
433   MLIRContext *context = sourceMap.getContext();
434 
435   // Compute the result exprs based on the reassociation maps.
436   for (auto &indices : reshapeOp.getReassociationIndices()) {
437     // Assume that they are in-order and contiguous (already checked in
438     // verifier).
439     assert(!indices.empty());
440     SmallVector<int64_t> sizes;
441     SmallVector<AffineExpr> dimExprs;
442     for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
443                              sourceExprs.slice(indices[0], indices.size()))) {
444       if (std::get<0>(en) == 1)
445         continue;
446       sizes.push_back(std::get<0>(en));
447       dimExprs.push_back(std::get<1>(en));
448     }
449     AffineExpr linearizedExpr =
450         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
451     resultExprs.push_back(linearizedExpr);
452   }
453   // The new affine map cannot drop unused dimension but some new symbols may
454   // have been added. Create a map with at least as many dimensions/symbols as
455   // the original affine map.
456   int64_t maxDim = -1;
457   int64_t maxSym = -1;
458   getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
459   unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
460   unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
461   return AffineMap::get(numDims, numSyms, resultExprs, context);
462 }
463 
464 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
465 // producer). Fusing when operand has higher rank will require use of mods and
466 // divs in the indexing maps of the fused op which would make it non-invertible.
467 static bool isTensorReshapeOpFoldableByLinearization(
468     tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
469   if (!asProducer)
470     return false;
471   return useIndexMap.isPermutation();
472 }
473 
474 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
475 // consumer).
476 static bool
477 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
478                                          AffineMap useIndexMap,
479                                          bool asProducer) {
480   if (asProducer)
481     return false;
482   return useIndexMap.isPermutation();
483 }
484 
485 /// Check if the reshape operation is only expansion into/collapsing of
486 /// unit-dimension.
487 template <typename TensorReshapeOp>
488 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
489   constexpr bool isExpanding =
490       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
491   ArrayRef<int64_t> expandedShape =
492       (isExpanding ? reshapeOp.getResultType().getShape()
493                    : reshapeOp.getSrcType().getShape());
494   for (auto &indices : reshapeOp.getReassociationIndices()) {
495     unsigned numUnitDims = 0;
496     for (int64_t position : indices)
497       if (expandedShape[position] == 1)
498         numUnitDims++;
499     if (numUnitDims != indices.size() - 1)
500       return false;
501   }
502   return true;
503 }
504 
505 namespace {
506 /// Pattern to fold tensor_expand_shape op with its consumer by using the source
507 /// of the reshape op as the operand in the consumer (instead of the result of
508 /// the tensor_collapse_shape). The corresponding index map in the consumer
509 /// needs to be modified to linearize the folded dimension.
510 ///
511 /// For example,
512 ///
513 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
514 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
515 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
516 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
517 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
518 ///        -> tensor<?x?x4x?xf32>
519 ///
520 /// can be folded into
521 ///
522 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
523 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
524 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
525 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
526 ///        -> tensor<?x?x4x?xf32>
527 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
528 struct FoldProducerReshapeOpByLinearization
529     : public OpRewritePattern<GenericOp> {
530   using OpRewritePattern<GenericOp>::OpRewritePattern;
531 
532   LogicalResult matchAndRewrite(GenericOp genericOp,
533                                 PatternRewriter &rewriter) const override {
534     if (!genericOp.hasTensorSemantics())
535       return failure();
536     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
537     for (const auto &en : llvm::enumerate(inputOperands)) {
538       auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
539       if (!reshapeOp)
540         continue;
541 
542       if (!isTensorReshapeOpFoldableByLinearization(
543               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
544               /*asProducer =*/true) ||
545           (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
546         continue;
547 
548       // Compute the fused operands list,
549       SmallVector<Value> fusedOperands = genericOp.getInputOperands();
550       fusedOperands[en.index()] = reshapeOp.src();
551       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
552       llvm::append_range(fusedOperands, outputOperands);
553 
554       // Compute indexing_maps for the fused operation. The indexing_maps for
555       // the operands of the consumers that arent fused are the same.
556       SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
557 
558       // Compute the indexing map to use for the result of the producer.
559       AffineMap modifiedMap =
560           linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
561       // The modified map cannot have symbols.
562       if (modifiedMap.getNumSymbols())
563         return failure();
564       for (AffineExpr expr : modifiedMap.getResults()) {
565         if (!expr.isPureAffine())
566           return failure();
567       }
568       fusedIndexMaps[en.index()] = modifiedMap;
569 
570       // Further check that the resulting index maps can be fused and
571       // inverted. Without this the resultant op is not legal.
572       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
573         return rewriter.notifyMatchFailure(
574             genericOp, "fused op loop bound computation failed");
575       }
576 
577       rewriter.startRootUpdate(genericOp);
578       genericOp->setOperands(fusedOperands);
579       genericOp.indexing_mapsAttr(
580           rewriter.getAffineMapArrayAttr(fusedIndexMaps));
581       rewriter.finalizeRootUpdate(genericOp);
582       return success();
583     }
584     return failure();
585   }
586 };
587 
588 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
589 /// producer. The corresponding index map in the consumer needs to be modified
590 /// to linearize the folded dimension.
591 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
592 struct FoldConsumerReshapeOpByLinearization
593     : public OpRewritePattern<TensorReshapeOp> {
594   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
595 
596   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
597                                 PatternRewriter &rewriter) const override {
598     GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
599     if (!producer || !producer.hasTensorSemantics() ||
600         producer.getNumOutputs() != 1 ||
601         !isTensorReshapeOpFoldableByLinearization(
602             reshapeOp,
603             producer.getTiedIndexingMap(producer.getOutputOperand(0)),
604             /*asProducer =*/false) ||
605         (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
606       return failure();
607     // The indexing_maps for the operands of the fused operation are same as
608     // those for the operands of the producer.
609     SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
610 
611     // Compute the indexing map to use for the operand of the producer.
612     AffineMap modifiedMap = linearizeCollapsedDims(
613         producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
614     for (AffineExpr expr : modifiedMap.getResults()) {
615       if (!expr.isPureAffine()) {
616         return rewriter.notifyMatchFailure(
617             producer, "fused op indexing map is not affine");
618       }
619     }
620     fusedIndexMaps.back() = modifiedMap;
621 
622     // Further check that the resulting index maps can be fused and
623     // inverted. Without this the resultant op is not legal.
624     if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
625       return rewriter.notifyMatchFailure(
626           producer, "fused op loop bound computation failed");
627     }
628 
629     Location loc = producer.getLoc();
630     SmallVector<Value> inputOperands = producer.getInputOperands();
631     Value output = rewriter.create<TensorReshapeOp>(
632         loc, producer.getOutputOperand(0)->get(),
633         reshapeOp.getReassociationExprs());
634     auto fusedOp = rewriter.create<GenericOp>(
635         loc, reshapeOp.getResultType(),
636         /*inputs=*/inputOperands,
637         // TODO: handle outputs.
638         /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
639         producer.iterator_types(),
640         /*doc=*/nullptr,
641         /*library_call=*/nullptr);
642     auto &fusedRegion = fusedOp->getRegion(0);
643     rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
644                                fusedRegion.begin());
645     rewriter.replaceOp(reshapeOp, fusedOp->getResults());
646     return success();
647   }
648 };
649 } // namespace
650 
651 //===---------------------------------------------------------------------===//
652 // Methods and patterns that fuse reshape ops with elementwise operations by
653 // expanding the dimensionality of the elementwise operations.
654 //===---------------------------------------------------------------------===//
655 
656 /// Conditions for folding a generic operation with a reshape op by expanding
657 /// the iteration space dimensionality for tensor operations. These are
658 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
659 /// following fusion pattern.
660 ///
661 ///  Consider
662 ///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
664 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
665 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
666 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
667 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
668 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
669 ///
670 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
671 ///  is increased to match the result (operand) of the tensor_expand_shape.
672 ///  The indexing_map of the fused tensor in the `genericOp` and the
673 ///  reassociation map helps compute the indexing maps of the modified op.
674 ///  For the above example, based on the reassociation map it
675 ///  can be concluded that
676 ///
677 ///  - The loop used to access the first dimension of the fused tensor is split
678 ///    into two.
679 ///  - The loop used to access the second dimension of the fused tensor is kept
680 ///    as is.
681 ///  - The loop used to access the third dimension of the fused tensor is split
682 ///    into three.
683 ///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
686 ///
687 ///   d0 -> e0, e1
688 ///   d1 -> e2, e3, e4
689 ///   d2 -> e5
690 ///
691 ///  substituting this, the generic op can be rewritten as
692 ///
///  %d = linalg.generic ins(%0, %1 : tensor<?x?x?x?x?x?xf32>,
///                                   tensor<?x?x?x?xf32>)
694 ///        indexing_maps =
695 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
696 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
697 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
698 ///
///  Since the operands of the linalg generic now have a higher rank, reshapes
///  are introduced to keep their types consistent
701 ///
702 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
703 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
706 ///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
709 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
710                                                OpOperand *fusableOpOperand) {
711   // Is fusable only if:
712   // - All the indexing maps for operands and results are projected
713   //   permutations.
714   // - The fused tensor is not a scalar.
715   // - All the loops are parallel loops.
716   return genericOp.hasTensorSemantics() &&
717          llvm::all_of(genericOp.indexing_maps().getValue(),
718                       [](Attribute attr) {
719                         return attr.cast<AffineMapAttr>()
720                             .getValue()
721                             .isProjectedPermutation();
722                       }) &&
723          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
724          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
725            return attr.cast<StringAttr>().getValue() ==
726                   getParallelIteratorTypeName();
727          });
728 }
729 
730 namespace {
731 /// Information needed to expand a generic operation to fold the reshape with
732 /// it.
733 class ExpansionInfo {
734 public:
735   // Computes the mapping from original dimensions of the op to the dimensions
736   // of the expanded op given the `indexingMap` of the fused operand/result of
737   // the generic op, the `reassocationMaps` of the reshape op and the shape of
738   // the expanded op.
739   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
740                         ArrayRef<AffineMap> reassociationMaps,
741                         ArrayRef<int64_t> expandedShape,
742                         ArrayRef<int64_t> collapsedShape,
743                         PatternRewriter &rewriter);
744   unsigned getOrigOpNumDims() const { return reassociation.size(); }
745   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
746   ReassociationIndicesRef getExpandedDims(unsigned i) const {
747     return reassociation[i];
748   }
749   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
750     return expandedShapeMap[i];
751   }
752   ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
753 
754 private:
755   /// Reassociation from the dimensions in the original operation to the
756   /// dimension of the expanded operation.
757   SmallVector<ReassociationIndices> reassociation;
758   /// Mapping from extent of loops in the original operation, to the extent of
759   /// loops in the expanded operation.
760   SmallVector<SmallVector<int64_t>> expandedShapeMap;
761   /// Extent of the loop in the original operation.
762   SmallVector<int64_t> originalLoopExtent;
763   unsigned expandedOpNumDims;
764 };
765 } // namespace
766 
767 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
768                                      OpOperand *fusableOpOperand,
769                                      ArrayRef<AffineMap> reassociationMaps,
770                                      ArrayRef<int64_t> expandedShape,
771                                      ArrayRef<int64_t> collapsedShape,
772                                      PatternRewriter &rewriter) {
773   if (reassociationMaps.empty())
774     return failure();
775   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
776 
777   Optional<SmallVector<int64_t, 4>> originalLoopRange =
778       linalgOp.getStaticLoopRanges();
779   if (!originalLoopRange)
780     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
781   originalLoopExtent.assign(originalLoopRange->begin(),
782                             originalLoopRange->end());
783 
784   reassociation.clear();
785   expandedShapeMap.clear();
786   // Compute the number of dimension in the expanded op that correspond to each
787   // dimension of the original op.
788   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
789   expandedShapeMap.resize(fusedIndexMap.getNumDims());
790   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
791     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
792     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
793     numExpandedDims[pos] = foldedDims.getNumResults();
794     ArrayRef<int64_t> shape =
795         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
796     expandedShapeMap[pos].assign(shape.begin(), shape.end());
797   }
798   // The remaining dimensions remain the same.
799   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
800     if (expandedShapeMap[i].empty())
801       expandedShapeMap[i] = {originalLoopExtent[i]};
802 
803   // Compute reassociation map from the original op to the expanded op.
804   unsigned sum = 0;
805   reassociation.reserve(fusedIndexMap.getNumDims());
806   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
807     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
808     reassociation.emplace_back(seq.begin(), seq.end());
809     sum += numFoldedDim.value();
810   }
811   expandedOpNumDims = sum;
812   return success();
813 }
814 
815 /// Epanding the body of a linalg operation requires adaptations of the accessed
816 /// loop indices. Specifically, access of indices in the original operation need
817 /// to be replaced with linearizations of indices in the expanded op. That
818 /// requires the shape of the expanded dimensions to be static (at least all but
819 /// the most significant). For now check that these are all statically sized.
820 /// Note that this could be extended to handle dynamic case, but the
821 /// implementation below uses `affine.apply` which seems to have issues when the
822 /// shapes are not static.
823 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
824                                            const ExpansionInfo &expansionInfo,
825                                            PatternRewriter &rewriter) {
826   if (!genericOp.hasIndexSemantics())
827     return success();
828   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
829     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
830     if (expandedShape.size() == 1)
831       continue;
832     for (int64_t shape : expandedShape.drop_front()) {
833       if (ShapedType::isDynamic(shape)) {
834         return rewriter.notifyMatchFailure(
835             genericOp, "cannot expand due to index semantics and dynamic dims");
836       }
837     }
838   }
839   return success();
840 }
841 
842 /// Return the indexing map to use in the expanded op for a given the
843 /// `indexingMap` of the original operation.
844 static AffineMap
845 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
846                            const ExpansionInfo &expansionInfo) {
847   SmallVector<AffineExpr> newExprs;
848   for (AffineExpr expr : indexingMap.getResults()) {
849     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
850     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
851         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
852           return builder.getAffineDimExpr(static_cast<unsigned>(v));
853         }));
854     newExprs.append(expandedExprs.begin(), expandedExprs.end());
855   }
856   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
857                         indexingMap.getNumSymbols(), newExprs,
858                         builder.getContext());
859 }
860 
861 /// Return the type of the operand/result to use in the expanded op given the
862 /// type in the original op.
863 static RankedTensorType getExpandedType(RankedTensorType originalType,
864                                         AffineMap indexingMap,
865                                         const ExpansionInfo &expansionInfo) {
866   SmallVector<int64_t> expandedShape;
867   for (AffineExpr expr : indexingMap.getResults()) {
868     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
869     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
870     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
871   }
872   return RankedTensorType::get(expandedShape, originalType.getElementType());
873 }
874 
875 /// Returns the reassociation maps to use in the `tensor.expand_shape`
876 /// operation to convert the operands of the original operation to operands of
877 /// the expanded operation. The same method is used to compute the
878 /// `tensor.collapse_shape` used to collapse the result of the expanded
879 /// op to get the value that can replace all uses of the results of the original
880 /// op.
881 static SmallVector<ReassociationIndices>
882 getReassociationForExpansion(AffineMap indexingMap,
883                              const ExpansionInfo &expansionInfo) {
884   SmallVector<ReassociationIndices> reassociation;
885   unsigned numReshapeDims = 0;
886   for (AffineExpr expr : indexingMap.getResults()) {
887     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
888     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
889     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
890         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
891     reassociation.emplace_back(std::move(indices));
892     numReshapeDims += numExpandedDims;
893   }
894   return reassociation;
895 }
896 
/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
///
/// For an original loop `d` expanded into loops (e0, e1, ..., ek) with extents
/// (s0, s1, ..., sk), every `linalg.index d` in the entry block of
/// `fusedRegion` is replaced by the row-major linearization
///   ((e0 * s1 + e1) * s2 + e2) ... * sk + ek
/// built as a chain of `affine.apply` ops. The outermost extent `s0` is never
/// used, which is why only the inner extents have to be static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  // The early-increment range allows `indexOp` to be replaced (and erased)
  // inside the loop body.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    // New ops are inserted directly after `indexOp`; the early-increment
    // iterator has already moved past `indexOp`, so the index ops created
    // here are not visited again by this loop.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    // Extents s1..sk (the outermost extent s0 is dropped).
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    // Indices e1..ek of the expanded loops.
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    // Start the accumulation from the outermost expanded index e0.
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      // isGenericOpExpandable is expected to have rejected dynamic extents.
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      // acc' = idx + acc * extent
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
939 
940 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
941 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
942 /// that those conditions have been satisfied.
943 static Optional<SmallVector<Value>>
944 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
945                            OpOperand *fusableOpOperand,
946                            PatternRewriter &rewriter) {
947   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
948          "preconditions for fuse operation failed");
949   // Check if reshape is expanding or collapsing.
950   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
951   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
952   bool isExpanding = (expandingReshapeOp != nullptr);
953   RankedTensorType expandedType = isExpanding
954                                       ? expandingReshapeOp.getResultType()
955                                       : collapsingReshapeOp.getSrcType();
956   RankedTensorType collapsedType = isExpanding
957                                        ? expandingReshapeOp.getSrcType()
958                                        : collapsingReshapeOp.getResultType();
959 
960   ExpansionInfo expansionInfo;
961   if (failed(expansionInfo.compute(
962           genericOp, fusableOpOperand,
963           isExpanding ? expandingReshapeOp.getReassociationMaps()
964                       : collapsingReshapeOp.getReassociationMaps(),
965           expandedType.getShape(), collapsedType.getShape(), rewriter)))
966     return llvm::None;
967 
968   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
969     return llvm::None;
970 
971   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
972       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
973         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
974       }));
975 
976   SmallVector<Value> expandedOpOperands;
977   expandedOpOperands.reserve(genericOp.getNumInputs());
978   for (OpOperand *opOperand : genericOp.getInputOperands()) {
979     if (opOperand == fusableOpOperand) {
980       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
981                                                : collapsingReshapeOp.src());
982       continue;
983     }
984     if (genericOp.isInputTensor(opOperand)) {
985       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
986       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
987       RankedTensorType expandedOperandType =
988           getExpandedType(opOperandType, indexingMap, expansionInfo);
989       if (expandedOperandType != opOperand->get().getType()) {
990         // Reshape the operand to get the right type.
991         SmallVector<ReassociationIndices> reassociation =
992             getReassociationForExpansion(indexingMap, expansionInfo);
993         if (failed(reshapeLikeShapesAreCompatible(
994                 [&](const Twine &msg) {
995                   return rewriter.notifyMatchFailure(genericOp, msg);
996                 },
997                 opOperandType.getShape(), expandedOperandType.getShape(),
998                 reassociation,
999                 /*isExpandingReshape=*/true)))
1000           return llvm::None;
1001         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
1002             genericOp.getLoc(), expandedOperandType, opOperand->get(),
1003             reassociation));
1004         continue;
1005       }
1006     }
1007     expandedOpOperands.push_back(opOperand->get());
1008   }
1009 
1010   Location loc = genericOp.getLoc();
1011   SmallVector<Value> outputs;
1012   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
1013     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1014     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
1015     RankedTensorType expandedOutputType =
1016         getExpandedType(opOperandType, indexingMap, expansionInfo);
1017     if (expandedOutputType != opOperand->get().getType()) {
1018       SmallVector<ReassociationIndices> reassociation =
1019           getReassociationForExpansion(indexingMap, expansionInfo);
1020       if (failed(reshapeLikeShapesAreCompatible(
1021               [&](const Twine &msg) {
1022                 return rewriter.notifyMatchFailure(genericOp, msg);
1023               },
1024               opOperandType.getShape(), expandedOutputType.getShape(),
1025               reassociation,
1026               /*isExpandingReshape=*/true)))
1027         return llvm::None;
1028       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
1029           genericOp.getLoc(), expandedOutputType, opOperand->get(),
1030           reassociation));
1031     }
1032   }
1033 
1034   // The iterator types of the expanded op are all parallel.
1035   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
1036                                        getParallelIteratorTypeName());
1037 
1038   TypeRange resultTypes = ValueRange(outputs).getTypes();
1039   auto fusedOp =
1040       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
1041                                  /*inputs=*/expandedOpOperands, outputs,
1042                                  expandedOpIndexingMaps, iteratorTypes);
1043   Region &fusedRegion = fusedOp->getRegion(0);
1044   Region &originalRegion = genericOp->getRegion(0);
1045   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
1046 
1047   // Update the index accesses after the expansion.
1048   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
1049 
1050   // Reshape the result values to their original shape if this is a collapsing
1051   // reshape folded into its consumer.
1052   SmallVector<Value> resultVals;
1053   for (OpResult opResult : genericOp->getOpResults()) {
1054     int64_t resultNumber = opResult.getResultNumber();
1055     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
1056       SmallVector<ReassociationIndices> reassociation =
1057           getReassociationForExpansion(
1058               genericOp.getTiedIndexingMap(
1059                   genericOp.getOutputOperand(resultNumber)),
1060               expansionInfo);
1061       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
1062           genericOp.getLoc(), opResult.getType(),
1063           fusedOp->getResult(resultNumber), reassociation));
1064     } else {
1065       resultVals.push_back(fusedOp->getResult(resultNumber));
1066     }
1067   }
1068   // Assuming a single result.
1069   return resultVals;
1070 }
1071 
1072 namespace {
1073 
1074 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
1075 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
1076 /// in the consumer is expanded.
1077 class FoldWithProducerReshapeOpByExpansion
1078     : public OpRewritePattern<GenericOp> {
1079 public:
1080   FoldWithProducerReshapeOpByExpansion(
1081       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1082       PatternBenefit benefit = 1)
1083       : OpRewritePattern<GenericOp>(context, benefit),
1084         controlFoldingReshapes(std::move(foldReshapes)) {}
1085 
1086   LogicalResult matchAndRewrite(GenericOp genericOp,
1087                                 PatternRewriter &rewriter) const override {
1088     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1089       tensor::CollapseShapeOp reshapeOp =
1090           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1091       if (!reshapeOp)
1092         continue;
1093       // Fold only if
1094       // - The tensor reshape op is folding.
1095       // - All constraints of fusing with reshape by expansion are met.
1096       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1097           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1098         continue;
1099 
1100       Optional<SmallVector<Value>> replacementValues =
1101           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1102       if (!replacementValues)
1103         return failure();
1104       rewriter.replaceOp(genericOp, replacementValues.getValue());
1105       return success();
1106     }
1107     return failure();
1108   }
1109 
1110 private:
1111   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1112 };
1113 
1114 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1115 /// by expanding the dimensionality of the loop in the producer op.
1116 struct FoldReshapeWithGenericOpByExpansion
1117     : public OpRewritePattern<tensor::ExpandShapeOp> {
1118 
1119   FoldReshapeWithGenericOpByExpansion(
1120       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1121       PatternBenefit benefit = 1)
1122       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1123         controlFoldingReshapes(std::move(foldReshapes)) {}
1124 
1125   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1126                                 PatternRewriter &rewriter) const override {
1127     // Fold only if all constraints of fusing with reshape by expansion are met.
1128     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1129     if (!producer || producer.getNumOutputs() != 1 ||
1130         !isFusableWithReshapeByDimExpansion(producer,
1131                                             producer.getOutputOperand(0)) ||
1132         !controlFoldingReshapes(producer->getResult(0),
1133                                 reshapeOp->getOpOperand(0)))
1134       return failure();
1135     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1136         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1137     if (!replacementValues)
1138       return failure();
1139     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1140     return success();
1141   }
1142 
1143 private:
1144   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1145 };
1146 } // namespace
1147 
1148 //===---------------------------------------------------------------------===//
1149 // Methods and patterns to fuse reshape with linalg.generic operations by
1150 // contraction of dimensions.
1151 //===---------------------------------------------------------------------===//
1152 
1153 /// For an `indexingMap` that is a projected permutation, if the range is to be
1154 /// collapsed using the given `reassociation`, get the reassociation in the
1155 /// domain that would keep the map a projected permutation.
1156 static SmallVector<ReassociationIndices>
1157 getDomainReassociation(AffineMap indexingMap,
1158                        ArrayRef<ReassociationIndices> rangeReassociation) {
1159   assert(indexingMap.isProjectedPermutation() &&
1160          "expected projected permutation map");
1161   unsigned counter = 0;
1162   SmallVector<ReassociationIndices> domainReassociation;
1163   llvm::SmallDenseSet<unsigned, 4> processedDomainDims;
1164   // Iterate over the reassociation indices.
1165   for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
1166     ReassociationIndices foldedDomainDims;
1167     for (auto rangeDim : foldedRangeDims) {
1168       (void)rangeDim;
1169       AffineDimExpr dimExpr =
1170           indexingMap.getResult(counter++).cast<AffineDimExpr>();
1171       foldedDomainDims.push_back(dimExpr.getPosition());
1172       processedDomainDims.insert(dimExpr.getPosition());
1173     }
1174     domainReassociation.emplace_back(std::move(foldedDomainDims));
1175   }
1176   // Fill in the missing domain dims.
1177   for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
1178     if (processedDomainDims.count(dim))
1179       continue;
1180     ReassociationIndices vec = {dim};
1181     domainReassociation.emplace_back(std::move(vec));
1182   }
1183 
1184   // Sort the reassociation using the first dimension of the folded range to
1185   // not create unnecessary transposes.
1186   llvm::sort(domainReassociation,
1187              [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1188                return lhs[0] < rhs[0];
1189              });
1190   return domainReassociation;
1191 }
1192 
1193 /// For a given `dimSequence`, check if the sequence is conserved in the
1194 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
1195 /// Non-existence of the sequence returns true as well.
1196 static bool isDimSequencePreserved(AffineMap indexingMap,
1197                                    ReassociationIndicesRef dimSequence) {
1198   assert(!dimSequence.empty() &&
1199          "expected non-empty list for dimension sequence");
1200   assert(indexingMap.isProjectedPermutation() &&
1201          "expected indexing map to be projected permutation");
1202 
1203   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1204   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1205 
1206   unsigned dimSequenceStart = dimSequence[0];
1207   for (auto expr : enumerate(indexingMap.getResults())) {
1208     unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
1209     // 1.  Check if this start of the sequence.
1210     if (dimInMapStart == dimSequenceStart) {
1211       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
1212         return false;
1213       // 1a. Check if sequence is preserved.
1214       for (auto dimInSequence : enumerate(dimSequence)) {
1215         unsigned dimInMap =
1216             indexingMap.getResult(expr.index() + dimInSequence.index())
1217                 .cast<AffineDimExpr>()
1218                 .getPosition();
1219         if (dimInMap != dimInSequence.value())
1220           return false;
1221       }
1222       // Found the sequence. Projected permutation
1223       // enforces that all AffineDimExprs in the result are unique, so no
1224       // further checks are needed.
1225       return true;
1226     }
1227     // 2. If position in the expr (which is of type AffineDimExpr) is part
1228     // of sequence, return false here. This implies the entire sequence does not
1229     // exist in the indexing map.
1230     if (sequenceElements.count(dimInMapStart))
1231       return false;
1232   }
1233   // 3. No element of sequence found. Return true.
1234   return true;
1235 }
1236 
1237 // Check if a generic op can be fused along an operand by collapsing dimensions.
1238 static bool isFusableWithReshapeByDimCollapse(
1239     GenericOp genericOp, OpOperand *fusableOperand,
1240     ArrayRef<ReassociationIndices> reassociation) {
1241   // Some basic checks for this fusion to be valid.
1242   if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1243     return false;
1244 
1245   if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
1246         return map.isProjectedPermutation();
1247       })) {
1248     return false;
1249   }
1250 
1251   // Get the reassociation for the iteration space.
1252   SmallVector<ReassociationIndices> iterationReassociation =
1253       getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
1254                              reassociation);
1255   if (iterationReassociation.empty()) {
1256     // If the domain reassociation indices is empty, then this is a scalar op.
1257     // Nothing to do.
1258     return false;
1259   }
1260 
1261   auto iteratorTypes = genericOp.iterator_types().getValue();
1262   ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
1263   for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
1264     // Check that for all indexing maps, the folded dimensions sequence is
1265     // preserved.
1266     if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
1267           return isDimSequencePreserved(indexingMap, foldedIterDims);
1268         }))
1269       return false;
1270     unsigned startDim = foldedIterDims[0];
1271     ArrayRef<Attribute> foldedIteratorTypes =
1272         iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
1273     // Check that all folded iterator types are either all parallel, or all
1274     // reduction.
1275     if (!llvm::all_of(
1276             foldedIteratorTypes,
1277             [](Attribute attr) { return isParallelIterator(attr); }) &&
1278         !llvm::all_of(foldedIteratorTypes,
1279                       [](Attribute attr) { return isReductionIterator(attr); }))
1280       return false;
1281   }
1282   return true;
1283 }
1284 
1285 /// Helper class to carry state while collapsing the `linalg.generic` op.
1286 namespace {
1287 class CollapsingInfo {
1288 public:
1289   CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
1290     iterationReassociation = std::move(reassociation);
1291     for (auto foldedIterDims : enumerate(iterationReassociation)) {
1292       foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
1293           foldedIterDims.index();
1294     }
1295   }
1296 
1297   // Returns the iteration space reassociation.
1298   ArrayRef<ReassociationIndices> getReassociationIndices() {
1299     return iterationReassociation;
1300   }
1301 
1302   // Returns true if the given dimension is the start of a sequence of folded
1303   // dimensions.
1304   bool isDimStartOfFoldedDims(unsigned dim) {
1305     return foldedDimStartToSequenceMap.count(dim);
1306   }
1307 
1308   // Return the folded dimensions starting at `dim`.
1309   ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
1310     assert(foldedDimStartToSequenceMap.count(dim) &&
1311            "invalid start dim of folded dim "
1312            "sequence");
1313     return iterationReassociation[foldedDimStartToSequenceMap[dim]];
1314   }
1315 
1316   // For a dim in the original op, return the dim in the collapsed op, that it
1317   // is mapped to. Expectes `dim` to be start of a folded dimension sequence.
1318   unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
1319     assert(foldedDimStartToSequenceMap.count(dim) &&
1320            "invalid start dim of folded dim sequence");
1321     return foldedDimStartToSequenceMap[dim];
1322   }
1323 
1324 private:
1325   /// Reassociation describing the folded iteration space dimensions.
1326   SmallVector<ReassociationIndices> iterationReassociation;
1327 
1328   /// Map from the starting dimensions of the folded dimension sequences to
1329   /// their index in `iterationReassociation`.
1330   llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
1331 };
1332 } // namespace
1333 
1334 /// Get the iterator types for the collapsed operation given the original
1335 /// iterator types and collapsed dimensions.
1336 static SmallVector<StringRef>
1337 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
1338                             CollapsingInfo &collapsingInfo) {
1339   SmallVector<StringRef> collapsedIteratorTypes;
1340   for (ReassociationIndicesRef foldedIterDims :
1341        collapsingInfo.getReassociationIndices()) {
1342     assert(!foldedIterDims.empty() &&
1343            "reassociation indices expected to have non-empty sets");
1344     // Just pick the iterator type of the first folded dim. Pre-condition checks
1345     // expected to have checked that iterator types of all folded dimensions are
1346     // the same.
1347     collapsedIteratorTypes.push_back(
1348         iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1349   }
1350   return collapsedIteratorTypes;
1351 }
1352 
1353 /// Compute the indexing map in the collapsed op that corresponds to the given
1354 /// `indexingMap` of the original operation.
1355 static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
1356                                            CollapsingInfo &collapsingInfo) {
1357   MLIRContext *context = indexingMap.getContext();
1358   assert(indexingMap.isProjectedPermutation() &&
1359          "expected indexing map to be projected permutation");
1360   SmallVector<AffineExpr> resultExprs;
1361   for (auto expr : indexingMap.getResults()) {
1362     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1363     if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
1364       resultExprs.push_back(getAffineDimExpr(
1365           collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
1366           context));
1367     }
1368   }
1369   return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
1370                         resultExprs, context);
1371 }
1372 
1373 /// Return the `reassociation` indices to use to collapse the operand when the
1374 /// iteration space of a generic op is collapsed.
1375 static SmallVector<ReassociationIndices>
1376 getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
1377   unsigned counter = 0;
1378   SmallVector<ReassociationIndices> operandReassociation;
1379   for (auto expr : indexingMap.getResults()) {
1380     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1381     if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
1382       unsigned numFoldedDims =
1383           collapsingInfo.getFoldedDimsStartingAt(dim).size();
1384       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1385       operandReassociation.emplace_back(range.begin(), range.end());
1386       counter += numFoldedDims;
1387     }
1388   }
1389   return operandReassociation;
1390 }
1391 
1392 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1393 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1394                                    OpOperand *opOperand,
1395                                    CollapsingInfo &collapsingInfo,
1396                                    OpBuilder &builder) {
1397   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1398   SmallVector<ReassociationIndices> operandReassociation =
1399       getOperandReassociation(indexingMap, collapsingInfo);
1400 
1401   // If the number of entries in the reassocation for the operand is same as the
1402   // number of results of the indexing map, then nothing to do for this operand.
1403   Value operand = opOperand->get();
1404   if (operandReassociation.size() == indexingMap.getNumResults())
1405     return operand;
1406 
1407   // Insert a reshape to collapse the dimensions.
1408   auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1409       loc, operand, operandReassociation);
1410   return reshapeOp.getResult();
1411 }
1412 
1413 /// Modify the `linalg.index` operations in the original generic op, to its
1414 /// value in the collapsed operation.
1415 void generateCollapsedIndexingRegion(Location loc, Block *block,
1416                                      CollapsingInfo &collapsingInfo,
1417                                      ValueRange loopRange,
1418                                      PatternRewriter &rewriter) {
1419   OpBuilder::InsertionGuard g(rewriter);
1420   rewriter.setInsertionPointToStart(block);
1421 
1422   // Collect all the original index ops.
1423   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1424 
1425   // For each folded dimension list resolve the original induction variable
1426   // values in terms of the folded dimension induction variable.
1427   //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1428   // can be inverted to
1429   //   i2 = i_{folded} % d2
1430   //   i1 = (i_{folded} / d2) % d1
1431   //   i0 = i_{folded} / (d1 * d2)
1432   llvm::DenseMap<unsigned, Value> indexReplacementVals;
1433   for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
1434     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1435     Value newIndexVal =
1436         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1437     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1438       indexReplacementVals[dim] =
1439           rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1440       newIndexVal =
1441           rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1442     }
1443     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1444   }
1445 
1446   for (auto indexOp : indexOps) {
1447     auto dim = indexOp.dim();
1448     rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1449   }
1450 }
1451 
/// Implementation of fusion with reshape operation by collapsing dimensions.
///
/// `reshapeOp` is either a tensor.collapse_shape or a tensor.expand_shape. Its
/// reassociation, together with the indexing map of `fusableOpOperand`, must
/// satisfy isFusableWithReshapeByDimCollapse (asserted below). The loops of
/// `genericOp` are grouped accordingly, and a new linalg.generic is built over
/// the collapsed iteration space. Operands are wrapped in tensor.collapse_shape
/// where needed.
///
/// Returns llvm::None when the number of collapsed loops equals the number of
/// original loops, i.e. nothing would be folded. This path is taken before any
/// new operation is built in this function.
///
/// Otherwise returns one value per result of `genericOp`. Each value is the
/// collapsed op's result, re-expanded with tensor.expand_shape when its rank
/// differs from the original result's rank. The body block of `genericOp` is
/// merged into the new op. `genericOp` itself is not replaced here; the caller
/// is expected to do that with the returned values.
static Optional<SmallVector<Value>>
fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
                            OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
  // Both supported reshape kinds expose their reassociation the same way.
  SmallVector<ReassociationIndices> reassociation =
      isa<tensor::CollapseShapeOp>(reshapeOp)
          ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
          : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
  assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
                                           reassociation) &&
         "preconditions for fusing with reshape by collapse failed");

  // Group the loops of `genericOp` (see getDomainReassociation); each group
  // becomes one loop of the collapsed op.
  CollapsingInfo collapsingInfo(getDomainReassociation(
      genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
  // Check for trivial no transformation cases. In that case return nothing.
  if (collapsingInfo.getReassociationIndices().size() ==
      genericOp.getNumLoops())
    return llvm::None;

  // Get the iterator types for the collapsed op.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc =
      rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});

  // Get the input operands.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    // The result types of the collapsed op follow the collapsed outputs.
    resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  // Move the original payload into the new (empty-bodied) op. Inputs and
  // outputs were mapped one-to-one above, so the original block arguments are
  // replaced positionally by the arguments of the new block.
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    // The range computations are inserted before the collapsed op so that
    // they are available to the index recomputation inside its body.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (auto originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    // Only results whose rank actually changed need to be re-expanded.
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}
1556 
1557 namespace {
1558 
1559 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1560 /// contracting dimensions of the loop.
1561 class FoldWithProducerReshapeOpByCollapsing
1562     : public OpRewritePattern<GenericOp> {
1563 public:
1564   FoldWithProducerReshapeOpByCollapsing(
1565       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1566       PatternBenefit benefit = 1)
1567       : OpRewritePattern<GenericOp>(context, benefit),
1568         controlFoldingReshapes(std::move(foldReshapes)) {}
1569 
1570   LogicalResult matchAndRewrite(GenericOp genericOp,
1571                                 PatternRewriter &rewriter) const override {
1572     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1573       tensor::ExpandShapeOp reshapeOp =
1574           opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
1575       if (!reshapeOp)
1576         continue;
1577 
1578       if (!isFusableWithReshapeByDimCollapse(
1579               genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
1580           !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
1581         continue;
1582       }
1583 
1584       Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
1585           genericOp, reshapeOp, opOperand, rewriter);
1586       if (!replacements) {
1587         return rewriter.notifyMatchFailure(
1588             genericOp, "failed to do the fusion by collapsing transformation");
1589       }
1590 
1591       rewriter.replaceOp(genericOp, replacements.getValue());
1592       return success();
1593     }
1594     return failure();
1595   }
1596 
1597 private:
1598   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1599 };
1600 } // namespace
1601 
1602 //===---------------------------------------------------------------------===//
1603 // Methods and patterns to convert tensor.expand_shape -> linalg.generic
1604 // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
1605 //===---------------------------------------------------------------------===//
1606 
1607 // TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
1608 // collapsing that provides a more general functionality. This pattern is very
1609 // specific to a particular use case. The fusion by collapsing can provide the
1610 // same control to clients using the control function there.
1611 
1612 static SmallVector<ReassociationIndices>
1613 getReassociationIndices(ArrayRef<AffineMap> maps) {
1614   SmallVector<ReassociationIndices> reassociation;
1615   for (AffineMap map : maps) {
1616     ReassociationIndices indices;
1617     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
1618       unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
1619       indices.push_back(pos);
1620     }
1621     reassociation.push_back(indices);
1622   }
1623   return reassociation;
1624 }
1625 
1626 namespace {
1627 /// Pattern to move rank reducing reshape after an elementwise linalg generic
1628 /// op. This is useful to expose more fusion opportunities between named ops and
1629 /// generic ops. This can only be done if there is no broadcast or permuation
1630 /// within the dimensions we need to merge.
1631 ///
1632 /// For example,
1633 ///
1634 ///  %0 = tensor.expand_shape %A [[0, 1], [2]]
1635 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
1636 ///  %2 = linalg.generic {indexing_maps = [
1637 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1638 ///    affine_map<(d0, d1, d2) -> (d2)>,
1639 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
1640 ///    ["parallel", "parallel", "parallel"]} {
1641 ///  } -> tensor<112x112x16xf32>
1642 ///
1643 ///  into
1644 ///
1645 ///  %2 = linalg.generic {indexing_maps = [
1646 ///    affine_map<(d0, d1) -> (d0, d1)>,
1647 ///    affine_map<(d0, d1) -> (d1)>,
1648 ///    affine_map<(d0, d1) -> (d0, d1)>],
1649 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
1650 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
1651 ///  } -> tensor<12544x16xf32>
1652 ///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
1653 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
1654 struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
1655   using OpRewritePattern<GenericOp>::OpRewritePattern;
1656 
1657   LogicalResult matchAndRewrite(GenericOp genericOp,
1658                                 PatternRewriter &rewriter) const override {
1659     // Only apply to elementwise linalg on tensor.
1660     if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
1661         genericOp.getNumParallelLoops() != genericOp.getNumLoops())
1662       return failure();
1663     // Only support identity output maps. It could be extended to permuations if
1664     // needed.
1665     if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
1666           return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
1667         }))
1668       return failure();
1669     int64_t destRank = genericOp.getNumParallelLoops();
1670     SmallVector<Value> newOperands = genericOp.getInputOperands();
1671     tensor::ExpandShapeOp reshapeFound;
1672     // 1. Look for tensor_expand_shape operands and figure out save the
1673     // dimensions merged.
1674     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
1675     for (const auto &en : llvm::enumerate(inputOperands)) {
1676       auto reshapeOp =
1677           en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
1678       if (!reshapeOp)
1679         continue;
1680       // TODO: We could support non-identity map as long as the merged
1681       // dimensions are still contiguous.
1682       if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
1683         continue;
1684       if (reshapeFound) {
1685         // Only support a second reshape op if it has the same reassociate maps.
1686         if (reshapeFound.getReassociationMaps() ==
1687             reshapeOp.getReassociationMaps())
1688           newOperands[en.index()] = reshapeOp.src();
1689         continue;
1690       }
1691       reshapeFound = reshapeOp;
1692       newOperands[en.index()] = reshapeOp.src();
1693     }
1694     if (!reshapeFound)
1695       return failure();
1696 
1697     // Calculate the reassociation indices and rassociated reverse map.
1698     SmallVector<ReassociationIndices> reassociation =
1699         getReassociationIndices(reshapeFound.getReassociationMaps());
1700     SmallVector<unsigned> remap(destRank);
1701     for (auto &indices : llvm::enumerate(reassociation)) {
1702       for (int64_t index : indices.value()) {
1703         remap[index] = indices.index();
1704       }
1705     }
1706     // 2. Verify that we can merge the dimensions in the linalg and that we
1707     // don't need to create new reshapes operands. Inserting new reshape
1708     // operands would defeat the purpose of the transformation.
1709     for (const auto &en : llvm::enumerate(inputOperands)) {
1710       if (en.value()->get() == newOperands[en.index()]) {
1711         AffineMap map = genericOp.getTiedIndexingMap(en.value());
1712         for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
1713           if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
1714             return failure();
1715         }
1716       }
1717     }
1718 
1719     // 3. Calculate the affine map remapping and the reassociation to apply to
1720     // output tensors.
1721     SmallVector<AffineMap> newMaps;
1722     unsigned newRank = reassociation.size();
1723     for (auto map : genericOp.getIndexingMaps()) {
1724       SmallVector<AffineExpr> newExprs;
1725       for (auto expr : map.getResults()) {
1726         unsigned position = expr.template cast<AffineDimExpr>().getPosition();
1727         // Skip dimension merged except for the last of the group.
1728         if (reassociation[remap[position]].back() == position) {
1729           newExprs.push_back(
1730               getAffineDimExpr(remap[position], genericOp.getContext()));
1731         }
1732       }
1733       newMaps.push_back(
1734           AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
1735     }
1736 
1737     // 4. Reshape the output tensors.
1738     SmallVector<Value> newOutputs;
1739     SmallVector<Type> newOutputTypes;
1740     for (auto output : genericOp.outputs()) {
1741       auto newOutputType = RankedTensorType::get(
1742           reshapeFound.getSrcType().getShape(),
1743           output.getType().template cast<RankedTensorType>().getElementType());
1744       Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
1745           genericOp->getLoc(), newOutputType, output, reassociation);
1746       newOutputTypes.push_back(newOutputType);
1747       newOutputs.push_back(newOutput);
1748     }
1749     // 5. Create a new generic op with lowerer rank.
1750     SmallVector<StringRef> iteratorTypes(newRank,
1751                                          getParallelIteratorTypeName());
1752     auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
1753                                             newOperands, newOutputs, newMaps,
1754                                             iteratorTypes);
1755     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
1756                                 newOp.region().begin());
1757     // 6. Reshape the so that the type matches the uses.
1758     SmallVector<Value> newResults;
1759     for (const auto &result : llvm::enumerate(newOp->getResults())) {
1760       newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
1761           genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
1762           result.value(), reassociation));
1763     }
1764     rewriter.replaceOp(genericOp, newResults);
1765     return success();
1766   }
1767 };
1768 } // namespace
1769 
1770 //===---------------------------------------------------------------------===//
1771 // Methods and patterns that fuse constants with linalg.generic operations.
1772 //===---------------------------------------------------------------------===//
1773 
1774 namespace {
1775 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1776 /// handle cases where the constant is not single-valued.
1777 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1778 public:
1779   FoldScalarOrSplatConstant(MLIRContext *context,
1780                             ControlElementwiseOpsFusionFn &fun,
1781                             PatternBenefit benefit = 1)
1782       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1783 
1784   LogicalResult matchAndRewrite(GenericOp genericOp,
1785                                 PatternRewriter &rewriter) const override {
1786     if (!genericOp.hasTensorSemantics())
1787       return failure();
1788     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1789       Operation *def = opOperand->get().getDefiningOp();
1790       Attribute constantAttr;
1791       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1792         {
1793           DenseElementsAttr splatAttr;
1794           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1795               splatAttr.isSplat() &&
1796               splatAttr.getType().getElementType().isIntOrFloat()) {
1797             constantAttr = splatAttr.getSplatValue<Attribute>();
1798             return true;
1799           }
1800         }
1801         {
1802           IntegerAttr intAttr;
1803           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1804             constantAttr = intAttr;
1805             return true;
1806           }
1807         }
1808         {
1809           FloatAttr floatAttr;
1810           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1811             constantAttr = floatAttr;
1812             return true;
1813           }
1814         }
1815         return false;
1816       };
1817 
1818       auto resultValue = opOperand->get().dyn_cast<OpResult>();
1819       if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
1820           !controlFn(resultValue, *opOperand))
1821         continue;
1822 
1823       // The operands and the indexing_maps of the fused operation the same as
1824       // the operands and indexing_maps of the generic operations with the
1825       // values at the constant index dropped.
1826       SmallVector<AffineMap> fusedIndexMaps;
1827       SmallVector<Value> fusedOperands;
1828       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1829       fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1830       fusedOperands.reserve(genericOp.getNumInputs());
1831       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1832       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1833         if (inputOperand == opOperand)
1834           continue;
1835         Value inputValue = inputOperand->get();
1836         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1837         fusedOperands.push_back(inputValue);
1838         fusedLocs.push_back(inputValue.getLoc());
1839       }
1840       for (OpOperand *outputOperand : genericOp.getOutputOperands())
1841         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1842 
1843       // Check if the operation shapes to loops map is computable.
1844       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1845         return rewriter.notifyMatchFailure(
1846             genericOp, "fused op loop bound computation failed");
1847       }
1848 
1849       // Create a constant scalar value from the splat constant.
1850       Value scalarConstant = rewriter.create<arith::ConstantOp>(
1851           def->getLoc(), constantAttr, constantAttr.getType());
1852 
1853       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1854       auto fusedOp = rewriter.create<GenericOp>(
1855           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1856           /*inputs=*/fusedOperands,
1857           /*outputs=*/outputOperands,
1858           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1859           genericOp.iterator_types(),
1860           /*doc=*/nullptr,
1861           /*library_call=*/nullptr);
1862 
1863       // Map the block argument corresponding to the replaced argument with the
1864       // scalar constant.
1865       Region &region = genericOp->getRegion(0);
1866       Block &entryBlock = *region.begin();
1867       BlockAndValueMapping mapping;
1868       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1869                   scalarConstant);
1870       Region &fusedRegion = fusedOp->getRegion(0);
1871       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1872                                  mapping);
1873       rewriter.replaceOp(genericOp, fusedOp->getResults());
1874       return success();
1875     }
1876     return failure();
1877   }
1878 
1879 private:
1880   ControlElementwiseOpsFusionFn controlFn;
1881 };
1882 
1883 /// Base class for constant folding linalg.generic ops with N inputs, 1 output,
1884 /// and permutation indexing maps.
1885 ///
1886 /// `ConcreteType` should provide methods with signatures
1887 ///
1888 /// ```c++
1889 ///   bool matchIndexingMaps(GenericOp genericOp) const;
1890 ///   RegionComputationFn getRegionComputeFn(GenericOp) const;
1891 /// ```
1892 ///
1893 /// The latter inspects the region and returns the computation inside as a
1894 /// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the output.
1896 template <typename ConcreteType>
1897 class FoldConstantBase : public OpRewritePattern<GenericOp> {
1898 public:
1899   struct APIntOrFloat {
1900     Optional<APInt> apInt;
1901     Optional<APFloat> apFloat;
1902   };
1903   struct APIntOrFloatArray {
1904     SmallVector<APInt> apInts;
1905     SmallVector<APFloat> apFloats;
1906   };
1907   using RegionComputationFn =
1908       std::function<APIntOrFloat(const APIntOrFloatArray &)>;
1909 
1910   FoldConstantBase(MLIRContext *context,
1911                    const ControlElementwiseOpsFusionFn &controlFn,
1912                    PatternBenefit benefit = 1)
1913       : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
1914 
1915   LogicalResult matchAndRewrite(GenericOp genericOp,
1916                                 PatternRewriter &rewriter) const override {
1917     if (genericOp.hasBufferSemantics())
1918       return failure();
1919 
1920     // Only support ops generating one output for now.
1921     if (genericOp.getNumOutputs() != 1)
1922       return failure();
1923 
1924     auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
1925     // Require the output types to be static give we are generating constants.
1926     if (!outputType || !outputType.hasStaticShape())
1927       return failure();
1928 
1929     if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
1930           return operand->get().getType().isa<ShapedType>();
1931         }))
1932       return failure();
1933 
1934     // Make sure all element types are the same.
1935     auto getOperandElementType = [](OpOperand *operand) {
1936       return operand->get().getType().cast<ShapedType>().getElementType();
1937     };
1938     if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
1939                                         getOperandElementType)))
1940       return failure();
1941 
1942     // We can only handle the case where we have int/float elements.
1943     auto elementType = outputType.getElementType();
1944     if (!elementType.isIntOrFloat())
1945       return failure();
1946 
1947     // Require all indexing maps to be permutations for now. This is common and
1948     // it simplifies input/output access greatly: we can do the data shuffling
1949     // entirely in the compiler, without needing to turn all indices into
1950     // Values, and then do affine apply on them, and then match back the
1951     // constant again.
1952     if (!llvm::all_of(genericOp.getIndexingMaps(),
1953                       [](AffineMap map) { return map.isPermutation(); }))
1954       return failure();
1955 
1956     for (OpOperand *operand : genericOp.getOutputOperands()) {
1957       if (genericOp.payloadUsesValueFromOperand(operand))
1958         return failure();
1959     }
1960 
1961     // Further check the indexing maps are okay for the ConcreteType.
1962     if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
1963       return failure();
1964 
1965     // Defer to the concrete type to check the region and discover the
1966     // computation inside.
1967     RegionComputationFn computeFn =
1968         static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
1969     if (!computeFn)
1970       return failure();
1971 
1972     // All inputs should be constants.
1973     int numInputs = genericOp.getNumInputs();
1974     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
1975     for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
1976       if (!matchPattern(operand.value()->get(),
1977                         m_Constant(&inputValues[operand.index()])))
1978         return failure();
1979     }
1980 
1981     // Identified this as a potential candidate for folding. Now check the
1982     // policy to see whether we are allowed to proceed.
1983     for (int i = 0; i < numInputs; ++i) {
1984       OpOperand *consumer = genericOp.getInputOperand(i);
1985       OpResult producer = consumer->get().cast<OpResult>();
1986       if (!controlFn(producer, *consumer))
1987         return failure();
1988     }
1989 
1990     auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
1991     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
1992     int64_t numElements = outputType.getNumElements();
1993 
1994     // Use APInt/APFloat instead of Attribute here for constructing the output.
1995     // This helps to avoid blowing up compiler memory usage: Attributes would
1996     // unify the following cases but they have lifetime as the MLIRContext.
1997     SmallVector<APInt> intOutputValues;
1998     SmallVector<APFloat> fpOutputValues;
1999     if (elementType.template isa<FloatType>())
2000       fpOutputValues.resize(numElements, APFloat(0.f));
2001     else
2002       intOutputValues.resize(numElements);
2003 
2004     // Return the constant dim positions from the given permutation map.
2005     auto getDimPositions = [](AffineMap map) {
2006       SmallVector<unsigned> dims;
2007       dims.reserve(map.getNumResults());
2008       for (AffineExpr result : map.getResults()) {
2009         dims.push_back(result.cast<AffineDimExpr>().getPosition());
2010       }
2011       return dims;
2012     };
2013 
2014     SmallVector<SmallVector<unsigned>> inputDims;
2015     for (int i = 0; i < numInputs; ++i)
2016       inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
2017     auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
2018     auto outputShape = outputType.getShape();
2019 
2020     // Allocate small vectors for index delinearization. Initial values do not
2021     // matter here as they will be overwritten later.
2022     SmallVector<uint64_t> indices(loopBounds.size(), 0);
2023     SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
2024     SmallVector<SmallVector<uint64_t>> srcIndices(
2025         numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
2026     SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
2027     uint64_t dstLinearIndex = 0;
2028 
2029     // Allocate spaces for compute function inputs. Initial values do not matter
2030     // here as they will be overwritten later.
2031     APIntOrFloatArray computeFnInputs;
2032 
2033     auto inputShapes = llvm::to_vector<4>(
2034         llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
2035           return operand->get().getType().cast<ShapedType>().getShape();
2036         }));
2037 
2038     // Given a `linearIndex`, remap it to a linear index to access linalg op
2039     // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
2040     // `srcLinearIndices`, `dstLinearIndex` in place.
2041     auto computeRemappedLinearIndex = [&](int linearIndex) {
2042       int totalCount = linearIndex;
2043       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
2044         indices[dim] = totalCount % loopBounds[dim];
2045         totalCount /= loopBounds[dim];
2046       }
2047 
2048       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
2049         for (int i = 0; i < numInputs; ++i)
2050           srcIndices[i][dim] = indices[inputDims[i][dim]];
2051         dstIndices[dim] = indices[outputDims[dim]];
2052       }
2053 
2054       dstLinearIndex = dstIndices.front();
2055       for (int i = 0; i < numInputs; ++i)
2056         srcLinearIndices[i] = srcIndices[i].front();
2057 
2058       for (int dim = 1; dim < outputType.getRank(); ++dim) {
2059         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
2060         for (int i = 0; i < numInputs; ++i)
2061           srcLinearIndices[i] =
2062               srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
2063       }
2064     };
2065 
2066     bool isFloat = elementType.isa<FloatType>();
2067     if (isFloat) {
2068       SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
2069       for (int i = 0; i < numInputs; ++i)
2070         inFpRanges.push_back(inputValues[i].getValues<APFloat>());
2071 
2072       computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
2073 
2074       // Transpose the input constant. Because we don't know its rank in
2075       // advance, we need to loop over the range [0, element count) and
2076       // delinearize the index.
2077       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2078         computeRemappedLinearIndex(linearIndex);
2079 
2080         // Collect constant elements for all inputs at this loop iteration.
2081         for (int i = 0; i < numInputs; ++i)
2082           computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
2083 
2084         // Invoke the computation to get the corresponding constant output
2085         // element.
2086         fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
2087       }
2088     } else {
2089       SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
2090       for (int i = 0; i < numInputs; ++i)
2091         inIntRanges.push_back(inputValues[i].getValues<APInt>());
2092 
2093       computeFnInputs.apInts.resize(numInputs);
2094 
2095       // Transpose the input constant. Because we don't know its rank in
2096       // advance, we need to loop over the range [0, element count) and
2097       // delinearize the index.
2098       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2099         computeRemappedLinearIndex(linearIndex);
2100 
2101         // Collect constant elements for all inputs at this loop iteration.
2102         for (int i = 0; i < numInputs; ++i)
2103           computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
2104 
2105         // Invoke the computation to get the corresponding constant output
2106         // element.
2107         intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
2108       }
2109     }
2110 
2111     DenseElementsAttr outputAttr =
2112         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
2113                 : DenseElementsAttr::get(outputType, intOutputValues);
2114 
2115     rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
2116     return success();
2117   }
2118 
2119 private:
2120   ControlElementwiseOpsFusionFn controlFn;
2121 };
2122 
2123 // Folds linalg.generic ops that are actually transposes on constant values.
2124 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
2125   using FoldConstantBase::FoldConstantBase;
2126 
2127   bool matchIndexingMaps(GenericOp genericOp) const {
2128     // We should have one input and one output.
2129     return genericOp.getIndexingMaps().size() == 2;
2130   }
2131 
2132   RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
2133     // Make sure the region only contains a yield op.
2134     Block &body = genericOp.region().front();
2135     if (!llvm::hasSingleElement(body))
2136       return nullptr;
2137     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
2138     if (!yieldOp)
2139       return nullptr;
2140 
2141     // The yield op should return the block argument corresponds to the input.
2142     for (Value yieldVal : yieldOp.values()) {
2143       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
2144       if (!yieldArg || yieldArg.getOwner() != &body)
2145         return nullptr;
2146       if (yieldArg.getArgNumber() != 0)
2147         return nullptr;
2148     }
2149 
2150     // No computation; just return the orginal value.
2151     return [](const APIntOrFloatArray &inputs) {
2152       if (inputs.apFloats.empty())
2153         return APIntOrFloat{inputs.apInts.front(), llvm::None};
2154       return APIntOrFloat{llvm::None, inputs.apFloats.front()};
2155     };
2156   }
2157 
2158   ControlElementwiseOpsFusionFn controlFn;
2159 };
2160 
2161 } // namespace
2162 
2163 //===---------------------------------------------------------------------===//
2164 // Miscellaneous patterns that help fusion.
2165 //===---------------------------------------------------------------------===//
2166 
2167 namespace {
2168 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
2169 /// the value of the `outs` operand is not used within the op.  This is only
2170 /// implemented for `linalg.generic` operations for now, but should hold for all
2171 /// linalg structured ops.
2172 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2173   using OpRewritePattern<GenericOp>::OpRewritePattern;
2174 
2175   LogicalResult matchAndRewrite(GenericOp op,
2176                                 PatternRewriter &rewriter) const override {
2177     rewriter.startRootUpdate(op);
2178     bool modifiedOutput = false;
2179     Location loc = op.getLoc();
2180     for (OpOperand *opOperand : op.getOutputOperands()) {
2181       if (!op.payloadUsesValueFromOperand(opOperand)) {
2182         Value operandVal = opOperand->get();
2183         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
2184         if (!operandType)
2185           continue;
2186 
2187         // If outs is already an `init_tensor` operation, nothing to do.
2188         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
2189         if (definingOp)
2190           continue;
2191         modifiedOutput = true;
2192         SmallVector<Value> dynamicDims;
2193         for (const auto &dim : llvm::enumerate(operandType.getShape())) {
2194           if (dim.value() != ShapedType::kDynamicSize)
2195             continue;
2196           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
2197               loc, operandVal, dim.index()));
2198         }
2199         Value initTensor = rewriter.create<InitTensorOp>(
2200             loc, dynamicDims, operandType.getShape(),
2201             operandType.getElementType());
2202         op->setOperand(opOperand->getOperandNumber(), initTensor);
2203       }
2204     }
2205     if (!modifiedOutput) {
2206       rewriter.cancelRootUpdate(op);
2207       return failure();
2208     }
2209     rewriter.finalizeRootUpdate(op);
2210     return success();
2211   }
2212 };
2213 } // namespace
2214 
2215 //===---------------------------------------------------------------------===//
// Methods that add the patterns described in this file to a pattern list.
2217 //===---------------------------------------------------------------------===//
2218 
2219 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
2220     RewritePatternSet &patterns) {
2221   patterns
2222       .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
2223            FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
2224            FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
2225            FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
2226           patterns.getContext());
2227 }
2228 
2229 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
2230     RewritePatternSet &patterns) {
2231   patterns
2232       .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
2233            FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
2234            FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
2235            FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
2236           patterns.getContext());
2237 }
2238 
2239 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2240     RewritePatternSet &patterns,
2241     const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
2242   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2243                                                     controlFoldingReshapes);
2244   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2245                                                      controlFoldingReshapes);
2246 }
2247 
2248 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2249     RewritePatternSet &patterns,
2250     const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
2251   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2252                                                       controlFoldingReshapes);
2253 }
2254 
2255 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2256     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
2257   auto *context = patterns.getContext();
2258   patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
2259                FoldConstantTranspose>(context,
2260                                       options.controlElementwiseOpsFusionFn);
2261   patterns.add<RemoveOutsDependency>(context);
2262   populateFoldReshapeOpsByExpansionPatterns(patterns,
2263                                             options.controlFoldingReshapesFn);
2264   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2265   GenericOp::getCanonicalizationPatterns(patterns, context);
2266   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2267   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2268   context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2269       patterns);
2270 }
2271 
2272 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
2273   auto *context = patterns.getContext();
2274   patterns.add<PushExpandingReshape>(context);
2275 }
2276 
2277 //===---------------------------------------------------------------------===//
2278 // Passes
2279 //===---------------------------------------------------------------------===//
2280 
2281 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
2282                                       OpOperand &consumer) {
2283   if (auto producerCollapseOp =
2284           dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
2285     return !isUnitDimExpansionOnly(producerCollapseOp);
2286   }
2287   if (auto consumerExpandOp =
2288           dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
2289     return !isUnitDimExpansionOnly(consumerExpandOp);
2290   }
2291   return true;
2292 }
2293 
2294 namespace {
2295 
2296 /// Pass that fuses generic ops on tensors. Used only for testing.
2297 struct LinalgElementwiseOpFusionPass
2298     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
2299   void runOnOperation() override {
2300     Operation *op = getOperation();
2301     RewritePatternSet patterns(op->getContext());
2302     ControlElementwiseOpsFusionFn allowFoldingFn =
2303         [](const OpResult &producer, const OpOperand &consumer) {
2304           return true;
2305         };
2306     populateElementwiseOpsFusionPatterns(
2307         patterns,
2308         LinalgElementwiseFusionOptions().setControlFoldingReshapes(
2309             allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
2310 
2311     // Use TopDownTraversal for compile time reasons
2312     GreedyRewriteConfig grc;
2313     grc.useTopDownTraversal = true;
2314     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
2315                                        grc);
2316   }
2317 };
2318 
2319 /// Pass to test folding of reshape ops with generic ops by linearization.
2320 struct FoldReshapeOpsByLinearizationPass
2321     : public LinalgFoldReshapeOpsByLinearizationBase<
2322           FoldReshapeOpsByLinearizationPass> {
2323   void runOnOperation() override {
2324     Operation *op = getOperation();
2325     RewritePatternSet patterns(op->getContext());
2326     populateFoldReshapeOpsByLinearizationPatterns(patterns);
2327     if (allowFoldingUnitDimReshapes) {
2328       populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
2329     }
2330     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
2331   }
2332 };
2333 
2334 } // namespace
2335 
2336 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
2337   return std::make_unique<LinalgElementwiseOpFusionPass>();
2338 }
2339 
2340 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
2341   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
2342 }
2343