//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the linalg dialect pass that fuses elementwise
// operations on tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 //===---------------------------------------------------------------------===//
32 // Methods and patterns that fuse elementwise `linalg.generic` operations.
33 //===---------------------------------------------------------------------===//
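//
// As an illustration (a sketch, not taken from an existing test): given two
// elementwise `linalg.generic` ops where the second consumes the result of
// the first,
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %0 = linalg.generic {indexing_maps = [#map, #map],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%a : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) { ... addf ... }
//   %1 = linalg.generic {indexing_maps = [#map, #map],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) { ... mulf ... }
//
// the patterns below inline the producer's body into the consumer, yielding a
// single `linalg.generic` that computes both operations and removing the
// intermediate tensor %0.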
34 
/// Return the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the producer result in the
/// consumer (`fusedConsumerArgIndexMap`).
38 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
39     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
40     AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg, the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute its inverse to get a map from producer result tensor index ->
  // producer loop.
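  //
  // For example (hypothetical maps, for illustration only): with
  //   producerResultIndexMap   = (d0, d1) -> (d1, d0)
  //   argMap                   = (d0, d1) -> (d0)
  //   fusedConsumerArgIndexMap = (d0, d1) -> (d1, d0)
  // the inverse of the producer result map is (d0, d1) -> (d1, d0); composing
  // argMap with it gives (d0, d1) -> (d1), and composing that with the
  // consumer map yields (d0, d1) -> (d0), i.e. the producer argument indexed
  // in the coordinates of the fused (consumer) loops.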
49   AffineMap invProducerResultIndexMap =
50       inversePermutation(producerResultIndexMap);
51   assert(invProducerResultIndexMap &&
52          "expected producer result indexing map to be invertible");
53 
54   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
55   // argMap is a map from producer loop -> producer arg tensor index.
56   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
57 
58   // Compose argMap with invProducerResultIndexMap to get a map from
59   // producer result tensor index -> producer arg tensor index.
60   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
61 
  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop (i.e. fused loop) -> producer arg tensor index.
64   return t1.compose(fusedConsumerArgIndexMap);
65 }
66 
67 /// Conditions for elementwise fusion of generic operations.
68 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
69                                      OpOperand *consumerOpOperand) {
70   // Producer and consumer must have tensor semantics.
71   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
72     return false;
73 
74   // Verify that
  // - the producer has all "parallel" iterator types.
76   if (producer.getNumParallelLoops() != producer.getNumLoops())
77     return false;
78 
79   // Only allow fusing the producer of an input operand for now.
80   // TODO: allow fusing the producer of an output operand.
81   if (!consumer.isInputTensor(consumerOpOperand))
82     return false;
83 
84   // Get the consumer index map. The number of results of the consumer index
85   // map must match the number of loops of the producer.
86   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
88     return false;
89 
  // Currently, only operations with a single result are supported.
91   if (producer.getNumOutputs() != 1)
92     return false;
93 
  // Finally, the indexing map for the result must be invertible. For now,
  // just verify it is a permutation.
96   AffineMap producerResultIndexMap =
97       producer.getTiedIndexingMap(producer.getOutputOperand(0));
98   if (!producerResultIndexMap.isPermutation())
99     return false;
100 
101   // Ensure that the fusion does not remove size information required to
102   // get the loop bounds. For non-reduction generics, this is trivially the
103   // case due to the output operand. For reductions, we need to check that after
104   // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
106     BitVector coveredDims(consumer.getNumLoops(), false);
107 
108     auto addToCoveredDims = [&](AffineMap map) {
109       for (auto result : map.getResults())
110         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
111           coveredDims[dimExpr.getPosition()] = true;
112     };
113 
114     for (auto pair :
115          llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
116       Value operand = std::get<0>(pair);
117       if (operand == consumerOpOperand->get())
118         continue;
119       AffineMap operandMap = std::get<1>(pair);
120       addToCoveredDims(operandMap);
121     }
122 
123     for (OpOperand *operand : producer.getInputOperands()) {
124       AffineMap newIndexingMap =
125           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
126               operand, producerResultIndexMap, consumerIndexMap);
127       addToCoveredDims(newIndexingMap);
128     }
129     if (!coveredDims.all())
130       return false;
131   }
132 
133   return true;
134 }
135 
136 /// Generate the region of the fused tensor operation. The region of the fused
137 /// op must be empty.
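///
/// The block arguments of the fused region are added in the following order
/// (mirroring the operand order built in `fuseElementwiseOpsImpl`): consumer
/// inputs before the fused operand, producer inputs, the producer init operand
/// if its value is actually read, the remaining consumer inputs, and finally
/// the consumer outputs.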
138 static void
139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
140                                  AffineMap consumerToProducerLoopsMap,
141                                  OpOperand *consumerOpOperand,
142                                  unsigned nloops) {
143   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
144   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
145   // Build the region of the fused op.
146   Block &producerBlock = producer->getRegion(0).front();
147   Block &consumerBlock = consumer->getRegion(0).front();
148   Block *fusedBlock = new Block();
149   fusedOp.region().push_back(fusedBlock);
150   BlockAndValueMapping mapper;
151   OpBuilder::InsertionGuard guard(rewriter);
152   rewriter.setInsertionPointToStart(fusedBlock);
153 
154   // 2. Add an index operation for every fused loop dimension and use the
155   // `consumerToProducerLoopsMap` to map the producer indices.
156   if (producer.hasIndexSemantics()) {
157     // Add an index operation for every fused loop dimension.
158     unsigned numFusedOpLoops =
159         std::max(producer.getNumLoops(), consumer.getNumLoops());
160     SmallVector<Value> fusedIndices;
161     fusedIndices.reserve(numFusedOpLoops);
162     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
164                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
165                     });
166     for (IndexOp indexOp :
167          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
169           producer.getLoc(),
170           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
171       mapper.map(indexOp.getResult(), newIndex);
172     }
173   }
174   // TODO: allow fusing the producer of an output operand.
175   assert(consumer.isInputTensor(consumerOpOperand) &&
176          "expected producer of input operand");
177   // 3. Consumer input operands up to consumerIdx (exclusive).
178   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
179            consumerOpOperand->getOperandNumber())) // input assumption.
180     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
181 
182   // Replacing consumerIdx requires getting the cloned, yielded, value from
183   // the (cloned) producer block. This happens in step 9.
184 
185   // 4. Splice in producer's input operands.
186   for (BlockArgument bbArg :
187        producerBlock.getArguments().take_front(producer.getNumInputs()))
188     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
189 
190   // 4.b. Producer output operand/map that is fused needs to be mapped to the
191   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
192   assert(producer->getNumResults() == 1 && "expected single result producer");
193   if (producer.isInitTensor(producer.getOutputOperand(0))) {
194     BlockArgument bbArg = producerBlock.getArguments()
195                               .drop_front(producer.getNumInputs())
196                               // TODO: bbArg index of
197                               .front();
198     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
199   }
200   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
201   for (BlockArgument bbArg :
202        consumerBlock.getArguments()
203            .take_front(consumer.getNumInputs())
204            .drop_front(consumerOpOperand->getOperandNumber() + 1))
205     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
206   // 6. All of consumer's output operands.
207   for (BlockArgument bbArg :
208        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
209     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
210   // 7. All of producer's output operands except the one fused.
211   // TODO: allow fusion of multi-result producers.
212   assert(producer->getNumResults() == 1 && "expected single result producer");
213 
214   // 8. Clone all producer operations except for the yield and index operations
215   // to the fused operation.
216   for (auto &op : producerBlock.without_terminator()) {
217     if (!isa<IndexOp>(op))
218       rewriter.clone(op, mapper);
219   }
220   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
221   // forward the yield operand.
222   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
223   // TODO: allow fusion of multi-result producers.
224   assert(producer->getNumResults() == 1 && "expected single result producer");
225   unsigned producerResultNumber = 0;
226   Value replacement =
227       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be defined outside the producer.
230   if (replacement == yieldOp.getOperand(producerResultNumber)) {
231     if (auto bb = replacement.dyn_cast<BlockArgument>())
232       assert(bb.getOwner() != &producerBlock &&
233              "yielded block argument must have been mapped");
234     else
235       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
236              "yielded value must have been mapped");
237   }
238   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
239              replacement);
240   // 10. Clone operations from the consumer to the fused op.
241   for (auto &op : consumerBlock.getOperations())
242     rewriter.clone(op, mapper);
243 
244   // Sanity checks.
245   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
246          "Ill-formed GenericOp region");
247 }
248 
249 static Optional<SmallVector<Value>>
250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
251                        const ControlElementwiseOpsFusionFn &controlFn,
252                        PatternRewriter &rewriter) {
253   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
254   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
255       !controlFn(producer->getResult(0), *consumerOpOperand))
256     return llvm::None;
257 
258   // TODO: allow fusing the producer of an output operand.
259   assert(consumer.isInputTensor(consumerOpOperand) &&
260          "expected producer of input operand");
261 
262   // Compute the fused operands list and indexing maps.
263   SmallVector<Value> fusedOperands;
264   SmallVector<AffineMap> fusedIndexMaps;
265   fusedOperands.reserve(producer->getNumOperands() +
266                         consumer->getNumOperands());
267   fusedIndexMaps.reserve(producer->getNumOperands() +
268                          consumer->getNumOperands());
  // In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
270   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
271   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
272   SmallVector<OpOperand *>::iterator it =
273       llvm::find(consumerInputs, consumerOpOperand);
274   assert(it != consumerInputs.end() && "expected to find the consumer operand");
275   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276     fusedOperands.push_back(opOperand->get());
277     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
278   }
279   // 4. Splice in producer's input operands/maps.
280   assert(producer->getNumResults() == 1 && "expected single result producer");
281   AffineMap producerResultIndexMap =
282       producer.getTiedIndexingMap(producer.getOutputOperand(0));
283   for (OpOperand *opOperand : producer.getInputOperands()) {
284     fusedOperands.push_back(opOperand->get());
285     // Compute indexing maps for the producer args in the fused operation.
286     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
287         opOperand, producerResultIndexMap,
288         consumer.getTiedIndexingMap(consumerOpOperand));
289     fusedIndexMaps.push_back(map);
290   }
291   // 4.b. Producer output operand/map that is fused needs to be passed if it is
292   // an "initTensor" (i.e. its value is actually read).
293   assert(producer->getNumResults() == 1 && "expected single result producer");
294   if (producer.isInitTensor(producer.getOutputOperand(0))) {
295     fusedOperands.push_back(producer.getOutputOperand(0)->get());
296     // Compute indexing maps for the producer args in the fused operation.
297     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
298         producer.getOutputOperand(0), producerResultIndexMap,
299         consumer.getTiedIndexingMap(consumerOpOperand));
300     fusedIndexMaps.push_back(map);
301   }
302   // 5. Remaining consumer's input operands/maps (drop past index
303   // `consumerIdx`).
304   for (OpOperand *opOperand :
305        llvm::make_range(std::next(it), consumerInputs.end())) {
306     fusedOperands.push_back(opOperand->get());
307     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
308   }
  // 6. All of consumer's output operand maps (the output operands themselves
  //    are passed to the builder below).
310   for (OpOperand *opOperand : consumer.getOutputOperands())
311     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
312   // 7. All of producer's output operands/maps except the one fused.
313   // TODO: allow fusion of multi-result producers.
314   assert(producer->getNumResults() == 1 && "expected single result producer");
315 
316   // Generate the fused op.
317   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
318   auto fusedOp = rewriter.create<GenericOp>(
319       consumer.getLoc(), consumer->getResultTypes(),
320       /*inputs=*/fusedOperands,
321       // TODO: handle outputs.
322       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
323       consumer.iterator_types(),
324       /*doc=*/nullptr,
325       /*library_call=*/nullptr);
326   if (!fusedOp.getShapesToLoopsMap()) {
327     // Fused op has invalid indexing maps. Typically this means something is off
328     // in the input, but going ahead here would result in verification errors.
329     // So cleanup and abort.
330     rewriter.eraseOp(fusedOp);
331     return llvm::None;
332   }
333 
334   // Construct an AffineMap from consumer loops to producer loops.
335   // consumer loop -> tensor index
336   AffineMap consumerResultIndexMap =
337       consumer.getTiedIndexingMap(consumerOpOperand);
338   // tensor index -> producer loop
339   AffineMap invProducerResultIndexMap =
340       inversePermutation(producerResultIndexMap);
341   assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
343   // consumer loop -> producer loop
344   AffineMap consumerToProducerLoopsMap =
345       invProducerResultIndexMap.compose(consumerResultIndexMap);
346 
347   generateFusedElementwiseOpRegion(rewriter, fusedOp,
348                                    consumerToProducerLoopsMap,
349                                    consumerOpOperand, consumer.getNumLoops());
350   return SmallVector<Value>(fusedOp->getResults());
351 }
352 
353 static Optional<SmallVector<Value>>
354 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
355                    GenericOp producer,
356                    const ControlElementwiseOpsFusionFn &controlFn) {
357   if (producer->getNumResults() != 1)
358     return llvm::None;
359 
360   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
361                                 rewriter);
362 }
363 
364 namespace {
/// Pattern to fuse a generic op with the producers of its operands.
366 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
367 public:
368   FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
369                      PatternBenefit benefit = 1)
370       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
371 
372   LogicalResult matchAndRewrite(GenericOp genericOp,
373                                 PatternRewriter &rewriter) const override {
374     // Find the first operand that is defined by another generic op on tensors.
375     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
376       auto producer =
377           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378       if (!producer || !producer.hasTensorSemantics())
379         continue;
380       Optional<SmallVector<Value>> fusedOpResults =
381           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
382       if (fusedOpResults) {
383         rewriter.replaceOp(genericOp, *fusedOpResults);
384         return success();
385       }
386     }
387     return failure();
388   }
389 
390 private:
391   ControlElementwiseOpsFusionFn controlFn;
392 };
393 } // namespace
394 
395 //===---------------------------------------------------------------------===//
396 // Methods and patterns that fuse reshape ops with elementwise operations by
397 // linearization of indexing maps.
398 //===---------------------------------------------------------------------===//
399 
// TODO(ravishankarm): The indexing maps these patterns produce in the general
// case are detrimental to transformations. These patterns are on a
// deprecation path in favor of fusion by collapsing, which covers the only
// legitimate use case of this pattern: folding unit-extent dims.
405 
406 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
407 /// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
409 /// are "row-major" ordered logically.
410 ///
411 /// For example:
412 ///
413 /// %0 = op ... : tensor<?x?x4x5xf32>
414 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
415 ///
416 /// and reshape:
/// %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
418 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
419 ///
420 /// would be rewritten into:
421 /// %0 = op ... : tensor<?x?x4x5xf32>
422 /// with output index_map
423 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
424 template <typename TensorReshapeOp>
425 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
426                                         TensorReshapeOp reshapeOp) {
427   constexpr bool isExpanding =
428       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
429   ArrayRef<int64_t> sourceShape =
430       (isExpanding ? reshapeOp.getResultType().getShape()
431                    : reshapeOp.getSrcType().getShape());
432   SmallVector<AffineExpr> resultExprs;
433   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
434   MLIRContext *context = sourceMap.getContext();
435 
436   // Compute the result exprs based on the reassociation maps.
437   for (auto &indices : reshapeOp.getReassociationIndices()) {
438     // Assume that they are in-order and contiguous (already checked in
439     // verifier).
440     assert(!indices.empty());
441     SmallVector<int64_t> sizes;
442     SmallVector<AffineExpr> dimExprs;
443     for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
444                              sourceExprs.slice(indices[0], indices.size()))) {
445       if (std::get<0>(en) == 1)
446         continue;
447       sizes.push_back(std::get<0>(en));
448       dimExprs.push_back(std::get<1>(en));
449     }
450     AffineExpr linearizedExpr =
451         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
452     resultExprs.push_back(linearizedExpr);
453   }
454   // The new affine map cannot drop unused dimension but some new symbols may
455   // have been added. Create a map with at least as many dimensions/symbols as
456   // the original affine map.
457   int64_t maxDim = -1;
458   int64_t maxSym = -1;
459   getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
460   unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
461   unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
462   return AffineMap::get(numDims, numSyms, resultExprs, context);
463 }
464 
// tensor::ExpandShapeOp is fusable with its consumer (i.e. the reshape acts as
// a producer). Fusing when the operand has a higher rank would require mods
// and divs in the indexing maps of the fused op, which would make it
// non-invertible.
468 static bool isTensorReshapeOpFoldableByLinearization(
469     tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
470   if (!asProducer)
471     return false;
472   return useIndexMap.isPermutation();
473 }
474 
475 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
476 // consumer).
477 static bool
478 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
479                                          AffineMap useIndexMap,
480                                          bool asProducer) {
481   if (asProducer)
482     return false;
483   return useIndexMap.isPermutation();
484 }
485 
/// Check if the reshape operation only expands into or collapses
/// unit-dimensions.
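///
/// For example (a sketch for illustration, not from an existing test),
///   tensor.expand_shape %0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32>
/// only introduces a unit dimension and passes this check, while
///   tensor.expand_shape %0 [[0, 1]] : tensor<10xf32> into tensor<2x5xf32>
/// does not, since both expanded dimensions are non-unit.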
488 template <typename TensorReshapeOp>
489 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
490   constexpr bool isExpanding =
491       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
492   ArrayRef<int64_t> expandedShape =
493       (isExpanding ? reshapeOp.getResultType().getShape()
494                    : reshapeOp.getSrcType().getShape());
495   for (auto &indices : reshapeOp.getReassociationIndices()) {
496     unsigned numUnitDims = 0;
497     for (int64_t position : indices)
498       if (expandedShape[position] == 1)
499         numUnitDims++;
500     if (numUnitDims != indices.size() - 1)
501       return false;
502   }
503   return true;
504 }
505 
506 namespace {
/// Pattern to fold a tensor.expand_shape op with its consumer by using the
/// source of the reshape op as the operand in the consumer (instead of the
/// result of the reshape). The corresponding indexing map in the consumer
/// needs to be modified to linearize the folded dimensions.
511 ///
512 /// For example,
513 ///
514 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
515 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
///      : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
517 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
518 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
519 ///        -> tensor<?x?x4x?xf32>
520 ///
521 /// can be folded into
522 ///
523 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
524 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
525 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
526 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
527 ///        -> tensor<?x?x4x?xf32>
528 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
529 struct FoldProducerReshapeOpByLinearization
530     : public OpRewritePattern<GenericOp> {
531   using OpRewritePattern<GenericOp>::OpRewritePattern;
532 
533   LogicalResult matchAndRewrite(GenericOp genericOp,
534                                 PatternRewriter &rewriter) const override {
535     if (!genericOp.hasTensorSemantics())
536       return failure();
537     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
538     for (const auto &en : llvm::enumerate(inputOperands)) {
539       auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
540       if (!reshapeOp)
541         continue;
542 
543       if (!isTensorReshapeOpFoldableByLinearization(
544               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
545               /*asProducer =*/true) ||
546           (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
547         continue;
548 
      // Compute the fused operands list.
550       SmallVector<Value> fusedOperands = genericOp.getInputOperands();
551       fusedOperands[en.index()] = reshapeOp.src();
552       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
553       llvm::append_range(fusedOperands, outputOperands);
554 
555       // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused stay the same.
557       SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
558 
559       // Compute the indexing map to use for the result of the producer.
560       AffineMap modifiedMap =
561           linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
562       // The modified map cannot have symbols.
563       if (modifiedMap.getNumSymbols())
564         return failure();
565       for (AffineExpr expr : modifiedMap.getResults()) {
566         if (!expr.isPureAffine())
567           return failure();
568       }
569       fusedIndexMaps[en.index()] = modifiedMap;
570 
571       // Further check that the resulting index maps can be fused and
572       // inverted. Without this the resultant op is not legal.
573       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
574         return rewriter.notifyMatchFailure(
575             genericOp, "fused op loop bound computation failed");
576       }
577 
578       rewriter.startRootUpdate(genericOp);
579       genericOp->setOperands(fusedOperands);
580       genericOp.indexing_mapsAttr(
581           rewriter.getAffineMapArrayAttr(fusedIndexMaps));
582       rewriter.finalizeRootUpdate(genericOp);
583       return success();
584     }
585     return failure();
586   }
587 };
588 
/// Pattern to fold a tensor.collapse_shape or tensor.expand_shape op with its
/// producer generic op. The indexing map of the producer result needs to be
/// modified to linearize the folded dimensions.
592 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
593 struct FoldConsumerReshapeOpByLinearization
594     : public OpRewritePattern<TensorReshapeOp> {
595   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
596 
597   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
598                                 PatternRewriter &rewriter) const override {
599     GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
600     if (!producer || !producer.hasTensorSemantics() ||
601         producer.getNumOutputs() != 1 ||
602         !isTensorReshapeOpFoldableByLinearization(
603             reshapeOp,
604             producer.getTiedIndexingMap(producer.getOutputOperand(0)),
605             /*asProducer =*/false) ||
606         (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
607       return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
610     SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
611 
612     // Compute the indexing map to use for the operand of the producer.
613     AffineMap modifiedMap = linearizeCollapsedDims(
614         producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
615     for (AffineExpr expr : modifiedMap.getResults()) {
616       if (!expr.isPureAffine()) {
617         return rewriter.notifyMatchFailure(
618             producer, "fused op indexing map is not affine");
619       }
620     }
621     fusedIndexMaps.back() = modifiedMap;
622 
623     // Further check that the resulting index maps can be fused and
624     // inverted. Without this the resultant op is not legal.
625     if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
626       return rewriter.notifyMatchFailure(
627           producer, "fused op loop bound computation failed");
628     }
629 
630     Location loc = producer.getLoc();
631     SmallVector<Value> inputOperands = producer.getInputOperands();
632     Value output = rewriter.create<TensorReshapeOp>(
633         loc, producer.getOutputOperand(0)->get(),
634         reshapeOp.getReassociationExprs());
635     auto fusedOp = rewriter.create<GenericOp>(
636         loc, reshapeOp.getResultType(),
637         /*inputs=*/inputOperands,
638         // TODO: handle outputs.
639         /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
640         producer.iterator_types(),
641         /*doc=*/nullptr,
642         /*library_call=*/nullptr);
643     auto &fusedRegion = fusedOp->getRegion(0);
644     rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
645                                fusedRegion.begin());
646     rewriter.replaceOp(reshapeOp, fusedOp->getResults());
647     return success();
648   }
649 };
650 } // namespace
651 
652 //===---------------------------------------------------------------------===//
653 // Methods and patterns that fuse reshape ops with elementwise operations by
654 // expanding the dimensionality of the elementwise operations.
655 //===---------------------------------------------------------------------===//
656 
/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `fuseWithReshapeByExpansion`, which implements the
/// following fusion pattern.
661 ///
662 ///  Consider
663 ///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
665 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
666 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
667 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
668 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
669 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
670 ///
671 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
672 ///  is increased to match the result (operand) of the tensor_expand_shape.
673 ///  The indexing_map of the fused tensor in the `genericOp` and the
674 ///  reassociation map helps compute the indexing maps of the modified op.
675 ///  For the above example, based on the reassociation map it
676 ///  can be concluded that
677 ///
678 ///  - The loop used to access the first dimension of the fused tensor is split
679 ///    into two.
680 ///  - The loop used to access the second dimension of the fused tensor is kept
681 ///    as is.
682 ///  - The loop used to access the third dimension of the fused tensor is split
683 ///    into three.
684 ///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing maps of the
///  modified op, then
687 ///
688 ///   d0 -> e0, e1
689 ///   d1 -> e2, e3, e4
690 ///   d2 -> e5
691 ///
692 ///  substituting this, the generic op can be rewritten as
693 ///
694 ///  %d = linalg.generic ins(%0, %1 : )
695 ///        indexing_maps =
696 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
697 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
698 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
699 ///
700 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
701 ///  to make it consistent
702 ///
703 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
704 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
705 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
707 ///
708 ///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
710 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
711                                                OpOperand *fusableOpOperand) {
  // The op is fusable only if:
713   // - All the indexing maps for operands and results are projected
714   //   permutations.
715   // - The fused tensor is not a scalar.
716   // - All the loops are parallel loops.
717   return genericOp.hasTensorSemantics() &&
718          llvm::all_of(genericOp.indexing_maps().getValue(),
719                       [](Attribute attr) {
720                         return attr.cast<AffineMapAttr>()
721                             .getValue()
722                             .isProjectedPermutation();
723                       }) &&
724          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
725          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
726            return attr.cast<StringAttr>().getValue() ==
727                   getParallelIteratorTypeName();
728          });
729 }
730 
731 namespace {
732 /// Information needed to expand a generic operation to fold the reshape with
733 /// it.
734 class ExpansionInfo {
735 public:
736   // Computes the mapping from original dimensions of the op to the dimensions
737   // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
739   // the expanded op.
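  //
  // For example (hypothetical values, for illustration only): with a fused
  // operand whose indexing map is (d0, d1) -> (d1, d0), a reshape with
  // reassociation [[0, 1], [2]] and an expanded shape [2, 3, 4], loop d1
  // (which indexes the first result of the map) expands into two loops of
  // extents [2, 3] and loop d0 expands into a single loop of extent [4]. The
  // computed reassociation from original loops to expanded loops is then
  // d0 -> [0], d1 -> [1, 2], and the expanded op has 3 loops.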
740   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
741                         ArrayRef<AffineMap> reassociationMaps,
742                         ArrayRef<int64_t> expandedShape,
743                         ArrayRef<int64_t> collapsedShape,
744                         PatternRewriter &rewriter);
745   unsigned getOrigOpNumDims() const { return reassociation.size(); }
746   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
747   ReassociationIndicesRef getExpandedDims(unsigned i) const {
748     return reassociation[i];
749   }
750   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
751     return expandedShapeMap[i];
752   }
753   ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
754 
755 private:
756   /// Reassociation from the dimensions in the original operation to the
757   /// dimension of the expanded operation.
758   SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of loops in the original operation to the extent
  /// of loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extents of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
764   unsigned expandedOpNumDims;
765 };
766 } // namespace
767 
768 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
769                                      OpOperand *fusableOpOperand,
770                                      ArrayRef<AffineMap> reassociationMaps,
771                                      ArrayRef<int64_t> expandedShape,
772                                      ArrayRef<int64_t> collapsedShape,
773                                      PatternRewriter &rewriter) {
774   if (reassociationMaps.empty())
775     return failure();
776   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
777 
778   Optional<SmallVector<int64_t, 4>> originalLoopRange =
779       linalgOp.getStaticLoopRanges();
780   if (!originalLoopRange)
781     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
782   originalLoopExtent.assign(originalLoopRange->begin(),
783                             originalLoopRange->end());
784 
785   reassociation.clear();
786   expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to each
788   // dimension of the original op.
789   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
790   expandedShapeMap.resize(fusedIndexMap.getNumDims());
791   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
792     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
793     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
794     numExpandedDims[pos] = foldedDims.getNumResults();
795     ArrayRef<int64_t> shape =
796         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
797     expandedShapeMap[pos].assign(shape.begin(), shape.end());
798   }
799   // The remaining dimensions remain the same.
800   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
801     if (expandedShapeMap[i].empty())
802       expandedShapeMap[i] = {originalLoopExtent[i]};
803 
804   // Compute reassociation map from the original op to the expanded op.
805   unsigned sum = 0;
806   reassociation.reserve(fusedIndexMap.getNumDims());
807   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
808     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
809     reassociation.emplace_back(seq.begin(), seq.end());
810     sum += numFoldedDim.value();
811   }
812   expandedOpNumDims = sum;
813   return success();
814 }
815 
/// Expanding the body of a linalg operation requires adapting the accessed
/// loop indices. Specifically, accesses of indices in the original operation
/// need to be replaced with linearizations of indices in the expanded op. That
/// requires the shape of the expanded dimensions to be static (at least all
/// but the most significant one). For now, check that these are all statically
/// sized. Note that this could be extended to handle the dynamic case, but the
/// implementation below uses `affine.apply`, which seems to have issues when
/// the shapes are not static.
824 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
825                                            const ExpansionInfo &expansionInfo,
826                                            PatternRewriter &rewriter) {
827   if (!genericOp.hasIndexSemantics())
828     return success();
829   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
830     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
831     if (expandedShape.size() == 1)
832       continue;
833     for (int64_t shape : expandedShape.drop_front()) {
834       if (ShapedType::isDynamic(shape)) {
835         return rewriter.notifyMatchFailure(
836             genericOp, "cannot expand due to index semantics and dynamic dims");
837       }
838     }
839   }
840   return success();
841 }
842 
/// Return the indexing map to use in the expanded op, given the `indexingMap`
/// of the original operation.
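///
/// For example (hypothetical maps, for illustration only): if original loop d1
/// expands into expanded loops 1 and 2, the original indexing map
/// (d0, d1) -> (d1, d0) becomes (d0, d1, d2) -> (d1, d2, d0) in the expanded
/// op.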
845 static AffineMap
846 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
847                            const ExpansionInfo &expansionInfo) {
848   SmallVector<AffineExpr> newExprs;
849   for (AffineExpr expr : indexingMap.getResults()) {
850     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
851     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
852         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
853           return builder.getAffineDimExpr(static_cast<unsigned>(v));
854         }));
855     newExprs.append(expandedExprs.begin(), expandedExprs.end());
856   }
857   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
858                         indexingMap.getNumSymbols(), newExprs,
859                         builder.getContext());
860 }
861 
862 /// Return the type of the operand/result to use in the expanded op given the
863 /// type in the original op.
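///
/// For example (hypothetical types, for illustration only): for an operand of
/// type tensor<?x10xf32> with indexing map (d0, d1) -> (d0, d1), if loop d1
/// expands into two loops of extents [2, 5], the expanded operand type is
/// tensor<?x2x5xf32>.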
864 static RankedTensorType getExpandedType(RankedTensorType originalType,
865                                         AffineMap indexingMap,
866                                         const ExpansionInfo &expansionInfo) {
867   SmallVector<int64_t> expandedShape;
868   for (AffineExpr expr : indexingMap.getResults()) {
869     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
870     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
871     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
872   }
873   return RankedTensorType::get(expandedShape, originalType.getElementType());
874 }
875 
876 /// Returns the reassociation maps to use in the `tensor.expand_shape`
877 /// operation to convert the operands of the original operation to operands of
878 /// the expanded operation. The same method is used to compute the
879 /// `tensor.collapse_shape` used to collapse the result of the expanded
880 /// op to get the value that can replace all uses of the results of the original
881 /// op.
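///
/// For example (hypothetical maps, for illustration only): for an indexing map
/// (d0, d1) -> (d1, d0) where loop d1 expands into two dimensions and loop d0
/// into one, the reassociation is [[0, 1], [2]]: the first two dimensions of
/// the expanded operand fold back into the original first dimension and the
/// last one into the original second dimension.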
882 static SmallVector<ReassociationIndices>
883 getReassociationForExpansion(AffineMap indexingMap,
884                              const ExpansionInfo &expansionInfo) {
885   SmallVector<ReassociationIndices> reassociation;
886   unsigned numReshapeDims = 0;
887   for (AffineExpr expr : indexingMap.getResults()) {
888     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
889     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
890     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
891         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
892     reassociation.emplace_back(std::move(indices));
893     numReshapeDims += numExpandedDims;
894   }
895   return reassociation;
896 }
897 
898 /// Update the body of an expanded linalg operation having index semantics. The
899 /// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now it
901 /// is assumed that the shapes of the expanded operation needed for
902 /// linearization are static.
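///
/// For example (hypothetical extents, for illustration only): if an original
/// loop is expanded into three loops with indices (e0, e1, e2) and the inner
/// expanded extents are statically known to be 4 and 5, the original index is
/// recovered as e0 * 20 + e1 * 5 + e2.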
903 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
904                                           Location loc, Region &fusedRegion,
905                                           const ExpansionInfo &expansionInfo) {
906   // Replace the original indices by the linearization of the expanded indices.
907   for (IndexOp indexOp :
908        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
909     ArrayRef<int64_t> expandedDims =
910         expansionInfo.getExpandedDims(indexOp.dim());
911     assert(!expandedDims.empty() && "expected valid expansion info");
912 
913     // Skip index operations that are not affected by the expansion.
914     if (expandedDims.size() == 1 &&
915         expandedDims.front() == (int64_t)indexOp.dim())
916       continue;
917 
918     // Linearize the expanded indices of the original index dimension.
919     OpBuilder::InsertionGuard guard(rewriter);
920     rewriter.setInsertionPointAfter(indexOp);
921     ArrayRef<int64_t> expandedDimsShape =
922         expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
923     SmallVector<Value> expandedIndices;
924     expandedIndices.reserve(expandedDims.size() - 1);
925     llvm::transform(
926         expandedDims.drop_front(), std::back_inserter(expandedIndices),
927         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
928     Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
929     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
930       assert(!ShapedType::isDynamic(std::get<0>(it)));
931       AffineExpr idx, acc;
932       bindDims(rewriter.getContext(), idx, acc);
933       newIndex = rewriter.create<AffineApplyOp>(
934           indexOp.getLoc(), idx + acc * std::get<0>(it),
935           ValueRange{std::get<1>(it), newIndex});
936     }
937     rewriter.replaceOp(indexOp, newIndex);
938   }
939 }
940 
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
944 static Optional<SmallVector<Value>>
945 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
946                            OpOperand *fusableOpOperand,
947                            PatternRewriter &rewriter) {
948   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
949          "preconditions for fuse operation failed");
950   // Check if reshape is expanding or collapsing.
951   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
952   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
953   bool isExpanding = (expandingReshapeOp != nullptr);
954   RankedTensorType expandedType = isExpanding
955                                       ? expandingReshapeOp.getResultType()
956                                       : collapsingReshapeOp.getSrcType();
957   RankedTensorType collapsedType = isExpanding
958                                        ? expandingReshapeOp.getSrcType()
959                                        : collapsingReshapeOp.getResultType();
960 
961   ExpansionInfo expansionInfo;
962   if (failed(expansionInfo.compute(
963           genericOp, fusableOpOperand,
964           isExpanding ? expandingReshapeOp.getReassociationMaps()
965                       : collapsingReshapeOp.getReassociationMaps(),
966           expandedType.getShape(), collapsedType.getShape(), rewriter)))
967     return llvm::None;
968 
969   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
970     return llvm::None;
971 
972   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
973       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
974         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
975       }));
976 
977   SmallVector<Value> expandedOpOperands;
978   expandedOpOperands.reserve(genericOp.getNumInputs());
979   for (OpOperand *opOperand : genericOp.getInputOperands()) {
980     if (opOperand == fusableOpOperand) {
981       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
982                                                : collapsingReshapeOp.src());
983       continue;
984     }
985     if (genericOp.isInputTensor(opOperand)) {
986       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
987       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
988       RankedTensorType expandedOperandType =
989           getExpandedType(opOperandType, indexingMap, expansionInfo);
990       if (expandedOperandType != opOperand->get().getType()) {
991         // Reshape the operand to get the right type.
992         SmallVector<ReassociationIndices> reassociation =
993             getReassociationForExpansion(indexingMap, expansionInfo);
994         if (failed(reshapeLikeShapesAreCompatible(
995                 [&](const Twine &msg) {
996                   return rewriter.notifyMatchFailure(genericOp, msg);
997                 },
998                 opOperandType.getShape(), expandedOperandType.getShape(),
999                 reassociation,
1000                 /*isExpandingReshape=*/true)))
1001           return llvm::None;
1002         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
1003             genericOp.getLoc(), expandedOperandType, opOperand->get(),
1004             reassociation));
1005         continue;
1006       }
1007     }
1008     expandedOpOperands.push_back(opOperand->get());
1009   }
1010 
1011   Location loc = genericOp.getLoc();
1012   SmallVector<Value> outputs;
1013   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
1014     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1015     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
1016     RankedTensorType expandedOutputType =
1017         getExpandedType(opOperandType, indexingMap, expansionInfo);
1018     if (expandedOutputType != opOperand->get().getType()) {
1019       SmallVector<ReassociationIndices> reassociation =
1020           getReassociationForExpansion(indexingMap, expansionInfo);
1021       if (failed(reshapeLikeShapesAreCompatible(
1022               [&](const Twine &msg) {
1023                 return rewriter.notifyMatchFailure(genericOp, msg);
1024               },
1025               opOperandType.getShape(), expandedOutputType.getShape(),
1026               reassociation,
1027               /*isExpandingReshape=*/true)))
1028         return llvm::None;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    } else {
      // The output type is not changed by the expansion; forward the original
      // operand so that `outputs` stays aligned with the results of the op.
      outputs.push_back(opOperand->get());
    }
  }
1034 
1035   // The iterator types of the expanded op are all parallel.
1036   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
1037                                        getParallelIteratorTypeName());
1038 
1039   TypeRange resultTypes = ValueRange(outputs).getTypes();
1040   auto fusedOp =
1041       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
1042                                  /*inputs=*/expandedOpOperands, outputs,
1043                                  expandedOpIndexingMaps, iteratorTypes);
1044   Region &fusedRegion = fusedOp->getRegion(0);
1045   Region &originalRegion = genericOp->getRegion(0);
1046   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
1047 
1048   // Update the index accesses after the expansion.
1049   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
1050 
1051   // Reshape the result values to their original shape if this is a collapsing
1052   // reshape folded into its consumer.
1053   SmallVector<Value> resultVals;
1054   for (OpResult opResult : genericOp->getOpResults()) {
1055     int64_t resultNumber = opResult.getResultNumber();
1056     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
1057       SmallVector<ReassociationIndices> reassociation =
1058           getReassociationForExpansion(
1059               genericOp.getTiedIndexingMap(
1060                   genericOp.getOutputOperand(resultNumber)),
1061               expansionInfo);
1062       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
1063           genericOp.getLoc(), opResult.getType(),
1064           fusedOp->getResult(resultNumber), reassociation));
1065     } else {
1066       resultVals.push_back(fusedOp->getResult(resultNumber));
1067     }
1068   }
  // Return the replacement values, one for each result of the original op.
1070   return resultVals;
1071 }
1072 
1073 namespace {
1074 
/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// iteration space in the consumer is expanded.
1078 class FoldWithProducerReshapeOpByExpansion
1079     : public OpRewritePattern<GenericOp> {
1080 public:
1081   FoldWithProducerReshapeOpByExpansion(
1082       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1083       PatternBenefit benefit = 1)
1084       : OpRewritePattern<GenericOp>(context, benefit),
1085         controlFoldingReshapes(std::move(foldReshapes)) {}
1086 
1087   LogicalResult matchAndRewrite(GenericOp genericOp,
1088                                 PatternRewriter &rewriter) const override {
1089     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1090       tensor::CollapseShapeOp reshapeOp =
1091           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1092       if (!reshapeOp)
1093         continue;
      // Fold only if
      // - The folding control function allows fusing this reshape.
      // - All constraints of fusing with reshape by expansion are met.
1097       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1098           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1099         continue;
1100 
1101       Optional<SmallVector<Value>> replacementValues =
1102           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1103       if (!replacementValues)
1104         return failure();
1105       rewriter.replaceOp(genericOp, replacementValues.getValue());
1106       return success();
1107     }
1108     return failure();
1109   }
1110 
1111 private:
1112   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1113 };
1114 
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the iteration space in the producer op.
1117 struct FoldReshapeWithGenericOpByExpansion
1118     : public OpRewritePattern<tensor::ExpandShapeOp> {
1119 
1120   FoldReshapeWithGenericOpByExpansion(
1121       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1122       PatternBenefit benefit = 1)
1123       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1124         controlFoldingReshapes(std::move(foldReshapes)) {}
1125 
1126   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1127                                 PatternRewriter &rewriter) const override {
1128     // Fold only if all constraints of fusing with reshape by expansion are met.
1129     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1130     if (!producer || producer.getNumOutputs() != 1 ||
1131         !isFusableWithReshapeByDimExpansion(producer,
1132                                             producer.getOutputOperand(0)) ||
1133         !controlFoldingReshapes(producer->getResult(0),
1134                                 reshapeOp->getOpOperand(0)))
1135       return failure();
1136     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1137         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1138     if (!replacementValues)
1139       return failure();
1140     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1141     return success();
1142   }
1143 
1144 private:
1145   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1146 };
1147 } // namespace
1148 
1149 //===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape ops with linalg.generic operations by
1151 // contraction of dimensions.
1152 //===---------------------------------------------------------------------===//
1153 
1154 /// For an `indexingMap` that is a projected permutation, if the range is to be
1155 /// collapsed using the given `reassociation`, get the reassociation in the
1156 /// domain that would keep the map a projected permutation.
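///
/// For example (a hypothetical map, for illustration only): for an indexing
/// map (d0, d1, d2, d3) -> (d2, d0, d3) and range reassociation [[0, 1], [2]],
/// the first folded group of range dimensions corresponds to loops (d2, d0)
/// and the second to (d3); loop d1 does not appear in the map and gets its own
/// group. After sorting by leading dimension, the domain reassociation is
/// [[1], [2, 0], [3]].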
1157 static SmallVector<ReassociationIndices>
1158 getDomainReassociation(AffineMap indexingMap,
1159                        ArrayRef<ReassociationIndices> rangeReassociation) {
1160   assert(indexingMap.isProjectedPermutation() &&
1161          "expected projected permutation map");
1162   unsigned counter = 0;
1163   SmallVector<ReassociationIndices> domainReassociation;
1164   llvm::SmallDenseSet<unsigned, 4> processedDomainDims;
1165   // Iterate over the reassociation indices.
1166   for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
1167     ReassociationIndices foldedDomainDims;
1168     for (auto rangeDim : foldedRangeDims) {
1169       (void)rangeDim;
1170       AffineDimExpr dimExpr =
1171           indexingMap.getResult(counter++).cast<AffineDimExpr>();
1172       foldedDomainDims.push_back(dimExpr.getPosition());
1173       processedDomainDims.insert(dimExpr.getPosition());
1174     }
1175     domainReassociation.emplace_back(std::move(foldedDomainDims));
1176   }
1177   // Fill in the missing domain dims.
1178   for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
1179     if (processedDomainDims.count(dim))
1180       continue;
1181     ReassociationIndices vec = {dim};
1182     domainReassociation.emplace_back(std::move(vec));
1183   }
1184 
1185   // Sort the reassociation using the first dimension of the folded range to
1186   // not create unnecessary transposes.
1187   llvm::sort(domainReassociation,
1188              [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1189                return lhs[0] < rhs[0];
1190              });
1191   return domainReassociation;
1192 }
1193 
/// For a given `dimSequence`, check if the sequence is preserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// If the sequence does not appear in the map at all, it is trivially
/// preserved and true is returned.
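///
/// As an illustrative example (map chosen arbitrarily): with
///   indexingMap = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
/// the sequence [2, 3] is preserved (d2 and d3 appear contiguously and in
/// order), the sequence [0, 1] is not (d0 is followed by d2 in the results),
/// and the sequence [1] is trivially preserved since d1 does not appear in
/// the results at all.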
1197 static bool isDimSequencePreserved(AffineMap indexingMap,
1198                                    ReassociationIndicesRef dimSequence) {
1199   assert(!dimSequence.empty() &&
1200          "expected non-empty list for dimension sequence");
1201   assert(indexingMap.isProjectedPermutation() &&
1202          "expected indexing map to be projected permutation");
1203 
1204   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1205   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1206 
1207   unsigned dimSequenceStart = dimSequence[0];
1208   for (const auto &expr : enumerate(indexingMap.getResults())) {
1209     unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
1211     if (dimInMapStart == dimSequenceStart) {
1212       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
1213         return false;
1214       // 1a. Check if sequence is preserved.
1215       for (const auto &dimInSequence : enumerate(dimSequence)) {
1216         unsigned dimInMap =
1217             indexingMap.getResult(expr.index() + dimInSequence.index())
1218                 .cast<AffineDimExpr>()
1219                 .getPosition();
1220         if (dimInMap != dimInSequence.value())
1221           return false;
1222       }
1223       // Found the sequence. Projected permutation
1224       // enforces that all AffineDimExprs in the result are unique, so no
1225       // further checks are needed.
1226       return true;
1227     }
    // 2. If the position in the expr (which is an AffineDimExpr) is part of
    // the sequence but is not its start, return false: the entire sequence
    // cannot appear in the indexing map.
1231     if (sequenceElements.count(dimInMapStart))
1232       return false;
1233   }
1234   // 3. No element of sequence found. Return true.
1235   return true;
1236 }
1237 
/// Check if a generic op can be fused along the given operand by collapsing
/// dimensions according to `reassociation`.
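///
/// As an illustrative example (maps chosen arbitrarily): with
/// `reassociation` = [[0, 1], [2]] on an operand whose indexing map is the
/// identity, the fusion is valid if every other indexing map either contains
/// the contiguous sequence d0, d1 or contains neither dimension; a map such
/// as affine_map<(d0, d1, d2) -> (d1, d2)> breaks the folded sequence and
/// blocks the fusion. All iterator types within a folded group must also
/// match (all parallel or all reduction).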
1239 static bool isFusableWithReshapeByDimCollapse(
1240     GenericOp genericOp, OpOperand *fusableOperand,
1241     ArrayRef<ReassociationIndices> reassociation) {
1242   // Some basic checks for this fusion to be valid.
1243   if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1244     return false;
1245 
1246   if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
1247         return map.isProjectedPermutation();
1248       })) {
1249     return false;
1250   }
1251 
1252   // Get the reassociation for the iteration space.
1253   SmallVector<ReassociationIndices> iterationReassociation =
1254       getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
1255                              reassociation);
1256   if (iterationReassociation.empty()) {
    // If the domain reassociation is empty, then this is a scalar op. Nothing
    // to do.
1259     return false;
1260   }
1261 
1262   auto iteratorTypes = genericOp.iterator_types().getValue();
1263   ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
1264   for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
1265     // Check that for all indexing maps, the folded dimensions sequence is
1266     // preserved.
1267     if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
1268           return isDimSequencePreserved(indexingMap, foldedIterDims);
1269         }))
1270       return false;
1271     unsigned startDim = foldedIterDims[0];
1272     ArrayRef<Attribute> foldedIteratorTypes =
1273         iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
1274     // Check that all folded iterator types are either all parallel, or all
1275     // reduction.
1276     if (!llvm::all_of(
1277             foldedIteratorTypes,
1278             [](Attribute attr) { return isParallelIterator(attr); }) &&
1279         !llvm::all_of(foldedIteratorTypes,
1280                       [](Attribute attr) { return isReductionIterator(attr); }))
1281       return false;
1282   }
1283   return true;
1284 }
1285 
1286 /// Helper class to carry state while collapsing the `linalg.generic` op.
1287 namespace {
1288 class CollapsingInfo {
1289 public:
1290   CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
1291     iterationReassociation = std::move(reassociation);
1292     for (const auto &foldedIterDims : enumerate(iterationReassociation)) {
1293       foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
1294           foldedIterDims.index();
1295     }
1296   }
1297 
1298   // Returns the iteration space reassociation.
1299   ArrayRef<ReassociationIndices> getReassociationIndices() {
1300     return iterationReassociation;
1301   }
1302 
1303   // Returns true if the given dimension is the start of a sequence of folded
1304   // dimensions.
1305   bool isDimStartOfFoldedDims(unsigned dim) {
1306     return foldedDimStartToSequenceMap.count(dim);
1307   }
1308 
1309   // Return the folded dimensions starting at `dim`.
1310   ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
1311     assert(foldedDimStartToSequenceMap.count(dim) &&
1312            "invalid start dim of folded dim "
1313            "sequence");
1314     return iterationReassociation[foldedDimStartToSequenceMap[dim]];
1315   }
1316 
  // For a dim in the original op, return the dim in the collapsed op that it
  // is mapped to. Expects `dim` to be the start of a folded dimension
  // sequence.
1319   unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
1320     assert(foldedDimStartToSequenceMap.count(dim) &&
1321            "invalid start dim of folded dim sequence");
1322     return foldedDimStartToSequenceMap[dim];
1323   }
1324 
1325 private:
1326   /// Reassociation describing the folded iteration space dimensions.
1327   SmallVector<ReassociationIndices> iterationReassociation;
1328 
1329   /// Map from the starting dimensions of the folded dimension sequences to
1330   /// their index in `iterationReassociation`.
1331   llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
1332 };
1333 } // namespace
1334 
1335 /// Get the iterator types for the collapsed operation given the original
1336 /// iterator types and collapsed dimensions.
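///
/// For example (illustrative), collapsing iterator types
/// ["parallel", "parallel", "reduction"] with reassociation [[0, 1], [2]]
/// yields ["parallel", "reduction"].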
1337 static SmallVector<StringRef>
1338 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
1339                             CollapsingInfo &collapsingInfo) {
1340   SmallVector<StringRef> collapsedIteratorTypes;
1341   for (ReassociationIndicesRef foldedIterDims :
1342        collapsingInfo.getReassociationIndices()) {
1343     assert(!foldedIterDims.empty() &&
1344            "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have verified that the iterator types of all
    // folded dimensions are the same.
1348     collapsedIteratorTypes.push_back(
1349         iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1350   }
1351   return collapsedIteratorTypes;
1352 }
1353 
1354 /// Compute the indexing map in the collapsed op that corresponds to the given
1355 /// `indexingMap` of the original operation.
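///
/// For example (illustrative), with iteration space reassociation
/// [[0, 1], [2]]:
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)> becomes
///   affine_map<(d0, d1) -> (d0, d1)>, and
///   affine_map<(d0, d1, d2) -> (d2)> becomes
///   affine_map<(d0, d1) -> (d1)>.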
1356 static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
1357                                            CollapsingInfo &collapsingInfo) {
1358   MLIRContext *context = indexingMap.getContext();
1359   assert(indexingMap.isProjectedPermutation() &&
1360          "expected indexing map to be projected permutation");
1361   SmallVector<AffineExpr> resultExprs;
1362   for (auto expr : indexingMap.getResults()) {
1363     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1364     if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
1365       resultExprs.push_back(getAffineDimExpr(
1366           collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
1367           context));
1368     }
1369   }
1370   return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
1371                         resultExprs, context);
1372 }
1373 
1374 /// Return the `reassociation` indices to use to collapse the operand when the
1375 /// iteration space of a generic op is collapsed.
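///
/// For example (illustrative), with iteration space reassociation
/// [[0, 1], [2]] and operand indexing map
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, the operand reassociation is
/// [[0, 1], [2]]; for affine_map<(d0, d1, d2) -> (d2)> it is [[0]].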
1376 static SmallVector<ReassociationIndices>
1377 getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
1378   unsigned counter = 0;
1379   SmallVector<ReassociationIndices> operandReassociation;
1380   for (auto expr : indexingMap.getResults()) {
1381     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1382     if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
1383       unsigned numFoldedDims =
1384           collapsingInfo.getFoldedDimsStartingAt(dim).size();
1385       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1386       operandReassociation.emplace_back(range.begin(), range.end());
1387       counter += numFoldedDims;
1388     }
1389   }
1390   return operandReassociation;
1391 }
1392 
1393 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1394 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1395                                    OpOperand *opOperand,
1396                                    CollapsingInfo &collapsingInfo,
1397                                    OpBuilder &builder) {
1398   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1399   SmallVector<ReassociationIndices> operandReassociation =
1400       getOperandReassociation(indexingMap, collapsingInfo);
1401 
  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, then nothing to do for this
  // operand.
1404   Value operand = opOperand->get();
1405   if (operandReassociation.size() == indexingMap.getNumResults())
1406     return operand;
1407 
1408   // Insert a reshape to collapse the dimensions.
1409   auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1410       loc, operand, operandReassociation);
1411   return reshapeOp.getResult();
1412 }
1413 
/// Rewrite the `linalg.index` operations in the original generic op to compute
/// their values in terms of the induction variables of the collapsed
/// operation.
static void generateCollapsedIndexingRegion(Location loc, Block *block,
                                            CollapsingInfo &collapsingInfo,
                                            ValueRange loopRange,
                                            PatternRewriter &rewriter) {
1420   OpBuilder::InsertionGuard g(rewriter);
1421   rewriter.setInsertionPointToStart(block);
1422 
1423   // Collect all the original index ops.
1424   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1425 
1426   // For each folded dimension list resolve the original induction variable
1427   // values in terms of the folded dimension induction variable.
1428   //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1429   // can be inverted to
1430   //   i2 = i_{folded} % d2
1431   //   i1 = (i_{folded} / d2) % d1
1432   //   i0 = i_{folded} / (d1 * d2)
1433   llvm::DenseMap<unsigned, Value> indexReplacementVals;
1434   for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
1435     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1436     Value newIndexVal =
1437         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1438     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1439       indexReplacementVals[dim] =
1440           rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1441       newIndexVal =
1442           rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1443     }
1444     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1445   }
1446 
1447   for (auto indexOp : indexOps) {
1448     auto dim = indexOp.dim();
1449     rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1450   }
1451 }
1452 
1453 /// Implementation of fusion with reshape operation by collapsing dimensions.
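///
/// As an illustrative sketch (shapes, maps, and names chosen arbitrarily):
/// fusing
///
///   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
///       : tensor<12x5xf32> into tensor<3x4x5xf32>
///   %1 = linalg.generic {indexing_maps = [
///     affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
///     iterator_types = ["parallel", "parallel", "parallel"]}
///     ins(%0 : tensor<3x4x5xf32>) outs(%init : tensor<3x4x5xf32>) {...}
///
/// along the use of %0 collapses the iteration space back to two dimensions:
/// the generic op then operates on collapsed tensor<12x5xf32> operands and
/// its result is re-expanded with a tensor.expand_shape to recover the
/// original tensor<3x4x5xf32> type.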
1454 static Optional<SmallVector<Value>>
1455 fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
1456                             OpOperand *fusableOpOperand,
1457                             PatternRewriter &rewriter) {
1458   SmallVector<ReassociationIndices> reassociation =
1459       isa<tensor::CollapseShapeOp>(reshapeOp)
1460           ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
1461           : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
1462   assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
1463                                            reassociation) &&
1464          "preconditions for fusing with reshape by collapse failed");
1465 
1466   CollapsingInfo collapsingInfo(getDomainReassociation(
1467       genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
  // Check for the trivial case where no transformation is needed; in that
  // case return nothing.
1469   if (collapsingInfo.getReassociationIndices().size() ==
1470       genericOp.getNumLoops())
1471     return llvm::None;
1472 
  // Get the iterator types for the collapsed operation.
1474   SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
1475       genericOp.iterator_types().getValue(), collapsingInfo);
1476 
1477   // Get the indexing maps.
1478   auto indexingMaps = llvm::to_vector(
1479       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
1480         return getCollapsedOpIndexingMap(map, collapsingInfo);
1481       }));
1482 
1483   Location loc =
1484       rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});
1485 
1486   // Get the input operands.
1487   auto inputOperands = llvm::to_vector(
1488       llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
1489         return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1490                                      rewriter);
1491       }));
1492 
1493   // Get the output operands and result types.
1494   SmallVector<Type> resultTypes;
1495   SmallVector<Value> outputOperands;
1496   resultTypes.reserve(genericOp.getNumOutputs());
1497   outputOperands.reserve(genericOp.getNumOutputs());
1498   for (OpOperand *output : genericOp.getOutputOperands()) {
1499     Value newOutput =
1500         getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
1501     outputOperands.push_back(newOutput);
1502     resultTypes.push_back(newOutput.getType());
1503   }
1504 
1505   // Create the generic op.
1506   auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
1507       loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1508       iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1509   Block *origOpBlock = &genericOp->getRegion(0).front();
1510   Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
1511   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1512                        collapsedOpBlock->getArguments());
1513 
1514   if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop ranges of the original generic op.
1516     OpBuilder::InsertionGuard g(rewriter);
1517     rewriter.setInsertionPoint(collapsedGenericOp);
1518     SmallVector<Range> loopRanges =
1519         cast<LinalgOp>(genericOp.getOperation())
1520             .createLoopRanges(rewriter, genericOp.getLoc());
1521     assert(llvm::all_of(loopRanges,
1522                         [](Range range) {
1523                           return matchPattern(range.offset, m_Zero()) &&
1524                                  matchPattern(range.stride, m_One());
1525                         }) &&
1526            "expected all loop ranges to have zero start and unit stride");
1527     SmallVector<Value> loopBound = llvm::to_vector(
1528         llvm::map_range(loopRanges, [](Range range) { return range.size; }));
1529     generateCollapsedIndexingRegion(loc,
1530                                     &collapsedGenericOp->getRegion(0).front(),
1531                                     collapsingInfo, loopBound, rewriter);
1532   }
1533 
  // Insert an expanding reshape for the result to get back the original
  // result type.
1536   SmallVector<Value> results;
1537   for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
1538     Value collapsedOpResult =
1539         collapsedGenericOp->getResult(originalResult.index());
1540     auto originalResultType =
1541         originalResult.value().getType().cast<ShapedType>();
1542     auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
1543     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1544       AffineMap indexingMap =
1545           genericOp.getTiedIndexingMapForResult(originalResult.value());
1546       SmallVector<ReassociationIndices> reassociation =
1547           getOperandReassociation(indexingMap, collapsingInfo);
1548       Value result = rewriter.create<tensor::ExpandShapeOp>(
1549           loc, originalResultType, collapsedOpResult, reassociation);
1550       results.push_back(result);
1551     } else {
1552       results.push_back(collapsedOpResult);
1553     }
1554   }
1555   return results;
1556 }
1557 
1558 namespace {
1559 
/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing dimensions of the iteration space.
1562 class FoldWithProducerReshapeOpByCollapsing
1563     : public OpRewritePattern<GenericOp> {
1564 public:
1565   FoldWithProducerReshapeOpByCollapsing(
1566       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1567       PatternBenefit benefit = 1)
1568       : OpRewritePattern<GenericOp>(context, benefit),
1569         controlFoldingReshapes(std::move(foldReshapes)) {}
1570 
1571   LogicalResult matchAndRewrite(GenericOp genericOp,
1572                                 PatternRewriter &rewriter) const override {
1573     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1574       tensor::ExpandShapeOp reshapeOp =
1575           opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
1576       if (!reshapeOp)
1577         continue;
1578 
1579       if (!isFusableWithReshapeByDimCollapse(
1580               genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
1581           !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
1582         continue;
1583       }
1584 
1585       Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
1586           genericOp, reshapeOp, opOperand, rewriter);
1587       if (!replacements) {
1588         return rewriter.notifyMatchFailure(
1589             genericOp, "failed to do the fusion by collapsing transformation");
1590       }
1591 
1592       rewriter.replaceOp(genericOp, replacements.getValue());
1593       return success();
1594     }
1595     return failure();
1596   }
1597 
1598 private:
1599   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1600 };
1601 } // namespace
1602 
1603 //===---------------------------------------------------------------------===//
1604 // Methods and patterns to convert tensor.expand_shape -> linalg.generic
1605 // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
1606 //===---------------------------------------------------------------------===//
1607 
// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
// collapsing, which provides more general functionality. This pattern is very
// specific to a particular use case. Fusion by collapsing can provide the
// same control to clients through the control function there.
1612 
1613 static SmallVector<ReassociationIndices>
1614 getReassociationIndices(ArrayRef<AffineMap> maps) {
1615   SmallVector<ReassociationIndices> reassociation;
1616   for (AffineMap map : maps) {
1617     ReassociationIndices indices;
1618     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
1619       unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
1620       indices.push_back(pos);
1621     }
1622     reassociation.push_back(indices);
1623   }
1624   return reassociation;
1625 }
1626 
1627 namespace {
/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
1632 ///
1633 /// For example,
1634 ///
1635 ///  %0 = tensor.expand_shape %A [[0, 1], [2]]
1636 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
1637 ///  %2 = linalg.generic {indexing_maps = [
1638 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1639 ///    affine_map<(d0, d1, d2) -> (d2)>,
1640 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
1641 ///    ["parallel", "parallel", "parallel"]} {
1642 ///  } -> tensor<112x112x16xf32>
1643 ///
1644 ///  into
1645 ///
1646 ///  %2 = linalg.generic {indexing_maps = [
1647 ///    affine_map<(d0, d1) -> (d0, d1)>,
1648 ///    affine_map<(d0, d1) -> (d1)>,
1649 ///    affine_map<(d0, d1) -> (d0, d1)>],
1650 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
1651 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
1652 ///  } -> tensor<12544x16xf32>
1653 ///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
1654 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
1655 struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
1656   using OpRewritePattern<GenericOp>::OpRewritePattern;
1657 
1658   LogicalResult matchAndRewrite(GenericOp genericOp,
1659                                 PatternRewriter &rewriter) const override {
1660     // Only apply to elementwise linalg on tensor.
1661     if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
1662         genericOp.getNumParallelLoops() != genericOp.getNumLoops())
1663       return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
1666     if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
1667           return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
1668         }))
1669       return failure();
1670     int64_t destRank = genericOp.getNumParallelLoops();
1671     SmallVector<Value> newOperands = genericOp.getInputOperands();
1672     tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor.expand_shape operands and save the dimensions being
    // merged.
1675     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
1676     for (const auto &en : llvm::enumerate(inputOperands)) {
1677       auto reshapeOp =
1678           en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
1679       if (!reshapeOp)
1680         continue;
      // TODO: We could support non-identity maps as long as the merged
      // dimensions are still contiguous.
1683       if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
1684         continue;
1685       if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
1687         if (reshapeFound.getReassociationMaps() ==
1688             reshapeOp.getReassociationMaps())
1689           newOperands[en.index()] = reshapeOp.src();
1690         continue;
1691       }
1692       reshapeFound = reshapeOp;
1693       newOperands[en.index()] = reshapeOp.src();
1694     }
1695     if (!reshapeFound)
1696       return failure();
1697 
    // Calculate the reassociation indices and the associated reverse map.
1699     SmallVector<ReassociationIndices> reassociation =
1700         getReassociationIndices(reshapeFound.getReassociationMaps());
1701     SmallVector<unsigned> remap(destRank);
1702     for (auto &indices : llvm::enumerate(reassociation)) {
1703       for (int64_t index : indices.value()) {
1704         remap[index] = indices.index();
1705       }
1706     }
    // 2. Verify that we can merge the dimensions in the linalg op and that we
    // don't need to create new reshape operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
1710     for (const auto &en : llvm::enumerate(inputOperands)) {
1711       if (en.value()->get() == newOperands[en.index()]) {
1712         AffineMap map = genericOp.getTiedIndexingMap(en.value());
1713         for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
1714           if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
1715             return failure();
1716         }
1717       }
1718     }
1719 
1720     // 3. Calculate the affine map remapping and the reassociation to apply to
1721     // output tensors.
1722     SmallVector<AffineMap> newMaps;
1723     unsigned newRank = reassociation.size();
1724     for (auto map : genericOp.getIndexingMaps()) {
1725       SmallVector<AffineExpr> newExprs;
1726       for (auto expr : map.getResults()) {
1727         unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip merged dimensions except for the last one of each group.
1729         if (reassociation[remap[position]].back() == position) {
1730           newExprs.push_back(
1731               getAffineDimExpr(remap[position], genericOp.getContext()));
1732         }
1733       }
1734       newMaps.push_back(
1735           AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
1736     }
1737 
1738     // 4. Reshape the output tensors.
1739     SmallVector<Value> newOutputs;
1740     SmallVector<Type> newOutputTypes;
1741     for (auto output : genericOp.outputs()) {
1742       auto newOutputType = RankedTensorType::get(
1743           reshapeFound.getSrcType().getShape(),
1744           output.getType().template cast<RankedTensorType>().getElementType());
1745       Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
1746           genericOp->getLoc(), newOutputType, output, reassociation);
1747       newOutputTypes.push_back(newOutputType);
1748       newOutputs.push_back(newOutput);
1749     }
    // 5. Create a new generic op with lower rank.
1751     SmallVector<StringRef> iteratorTypes(newRank,
1752                                          getParallelIteratorTypeName());
1753     auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
1754                                             newOperands, newOutputs, newMaps,
1755                                             iteratorTypes);
1756     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
1757                                 newOp.region().begin());
    // 6. Reshape the results so that their types match the uses.
1759     SmallVector<Value> newResults;
1760     for (const auto &result : llvm::enumerate(newOp->getResults())) {
1761       newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
1762           genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
1763           result.value(), reassociation));
1764     }
1765     rewriter.replaceOp(genericOp, newResults);
1766     return success();
1767   }
1768 };
1769 } // namespace
1770 
1771 //===---------------------------------------------------------------------===//
1772 // Methods and patterns that fuse constants with linalg.generic operations.
1773 //===---------------------------------------------------------------------===//
1774 
1775 namespace {
1776 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1777 /// handle cases where the constant is not single-valued.
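///
/// As an illustrative sketch (values chosen arbitrarily): if one of the
/// inputs of the generic op is
///   %cst = arith.constant dense<2.0> : tensor<4xf32>
/// the fused op drops that operand and its indexing map, and instead uses
///   %scalar = arith.constant 2.0 : f32
/// directly in place of the corresponding block argument of the region.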
1778 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1779 public:
1780   FoldScalarOrSplatConstant(MLIRContext *context,
1781                             ControlElementwiseOpsFusionFn &fun,
1782                             PatternBenefit benefit = 1)
1783       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
1784 
1785   LogicalResult matchAndRewrite(GenericOp genericOp,
1786                                 PatternRewriter &rewriter) const override {
1787     if (!genericOp.hasTensorSemantics())
1788       return failure();
1789     for (OpOperand *opOperand : genericOp.getInputOperands()) {
1790       Operation *def = opOperand->get().getDefiningOp();
1791       Attribute constantAttr;
1792       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1793         {
1794           DenseElementsAttr splatAttr;
1795           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1796               splatAttr.isSplat() &&
1797               splatAttr.getType().getElementType().isIntOrFloat()) {
1798             constantAttr = splatAttr.getSplatValue<Attribute>();
1799             return true;
1800           }
1801         }
1802         {
1803           IntegerAttr intAttr;
1804           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1805             constantAttr = intAttr;
1806             return true;
1807           }
1808         }
1809         {
1810           FloatAttr floatAttr;
1811           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1812             constantAttr = floatAttr;
1813             return true;
1814           }
1815         }
1816         return false;
1817       };
1818 
1819       auto resultValue = opOperand->get().dyn_cast<OpResult>();
1820       if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
1821           !controlFn(resultValue, *opOperand))
1822         continue;
1823 
      // The operands and the indexing_maps of the fused operation are the
      // same as the operands and indexing_maps of the generic operation, with
      // the entry for the constant operand dropped.
1827       SmallVector<AffineMap> fusedIndexMaps;
1828       SmallVector<Value> fusedOperands;
1829       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1830       fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1831       fusedOperands.reserve(genericOp.getNumInputs());
1832       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1833       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
1834         if (inputOperand == opOperand)
1835           continue;
1836         Value inputValue = inputOperand->get();
1837         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1838         fusedOperands.push_back(inputValue);
1839         fusedLocs.push_back(inputValue.getLoc());
1840       }
1841       for (OpOperand *outputOperand : genericOp.getOutputOperands())
1842         fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1843 
      // Check if the operation's shapes-to-loops map is computable.
1845       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1846         return rewriter.notifyMatchFailure(
1847             genericOp, "fused op loop bound computation failed");
1848       }
1849 
1850       // Create a constant scalar value from the splat constant.
1851       Value scalarConstant = rewriter.create<arith::ConstantOp>(
1852           def->getLoc(), constantAttr, constantAttr.getType());
1853 
1854       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1855       auto fusedOp = rewriter.create<GenericOp>(
1856           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1857           /*inputs=*/fusedOperands,
1858           /*outputs=*/outputOperands,
1859           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1860           genericOp.iterator_types(),
1861           /*doc=*/nullptr,
1862           /*library_call=*/nullptr);
1863 
      // Map the block argument corresponding to the replaced operand to the
      // scalar constant.
1866       Region &region = genericOp->getRegion(0);
1867       Block &entryBlock = *region.begin();
1868       BlockAndValueMapping mapping;
1869       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1870                   scalarConstant);
1871       Region &fusedRegion = fusedOp->getRegion(0);
1872       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1873                                  mapping);
1874       rewriter.replaceOp(genericOp, fusedOp->getResults());
1875       return success();
1876     }
1877     return failure();
1878   }
1879 
1880 private:
1881   ControlElementwiseOpsFusionFn controlFn;
1882 };
1883 
1884 /// Base class for constant folding linalg.generic ops with N inputs, 1 output,
1885 /// and permutation indexing maps.
1886 ///
1887 /// `ConcreteType` should provide methods with signatures
1888 ///
1889 /// ```c++
1890 ///   bool matchIndexingMaps(GenericOp genericOp) const;
1891 ///   RegionComputationFn getRegionComputeFn(GenericOp) const;
1892 /// ```
1893 ///
1894 /// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the
/// output.
1897 template <typename ConcreteType>
1898 class FoldConstantBase : public OpRewritePattern<GenericOp> {
1899 public:
1900   struct APIntOrFloat {
1901     Optional<APInt> apInt;
1902     Optional<APFloat> apFloat;
1903   };
1904   struct APIntOrFloatArray {
1905     SmallVector<APInt> apInts;
1906     SmallVector<APFloat> apFloats;
1907   };
1908   using RegionComputationFn =
1909       std::function<APIntOrFloat(const APIntOrFloatArray &)>;
1910 
1911   FoldConstantBase(MLIRContext *context,
1912                    const ControlElementwiseOpsFusionFn &controlFn,
1913                    PatternBenefit benefit = 1)
1914       : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
1915 
1916   LogicalResult matchAndRewrite(GenericOp genericOp,
1917                                 PatternRewriter &rewriter) const override {
1918     if (genericOp.hasBufferSemantics())
1919       return failure();
1920 
1921     // Only support ops generating one output for now.
1922     if (genericOp.getNumOutputs() != 1)
1923       return failure();
1924 
1925     auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static, given that we are generating
    // constants.
1927     if (!outputType || !outputType.hasStaticShape())
1928       return failure();
1929 
1930     if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
1931           return operand->get().getType().isa<ShapedType>();
1932         }))
1933       return failure();
1934 
1935     // Make sure all element types are the same.
1936     auto getOperandElementType = [](OpOperand *operand) {
1937       return operand->get().getType().cast<ShapedType>().getElementType();
1938     };
1939     if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
1940                                         getOperandElementType)))
1941       return failure();
1942 
1943     // We can only handle the case where we have int/float elements.
1944     auto elementType = outputType.getElementType();
1945     if (!elementType.isIntOrFloat())
1946       return failure();
1947 
1948     // Require all indexing maps to be permutations for now. This is common and
1949     // it simplifies input/output access greatly: we can do the data shuffling
1950     // entirely in the compiler, without needing to turn all indices into
1951     // Values, and then do affine apply on them, and then match back the
1952     // constant again.
1953     if (!llvm::all_of(genericOp.getIndexingMaps(),
1954                       [](AffineMap map) { return map.isPermutation(); }))
1955       return failure();
1956 
1957     for (OpOperand *operand : genericOp.getOutputOperands()) {
1958       if (genericOp.payloadUsesValueFromOperand(operand))
1959         return failure();
1960     }
1961 
1962     // Further check the indexing maps are okay for the ConcreteType.
1963     if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
1964       return failure();
1965 
1966     // Defer to the concrete type to check the region and discover the
1967     // computation inside.
1968     RegionComputationFn computeFn =
1969         static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
1970     if (!computeFn)
1971       return failure();
1972 
1973     // All inputs should be constants.
1974     int numInputs = genericOp.getNumInputs();
1975     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
1976     for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
1977       if (!matchPattern(operand.value()->get(),
1978                         m_Constant(&inputValues[operand.index()])))
1979         return failure();
1980     }
1981 
1982     // Identified this as a potential candidate for folding. Now check the
1983     // policy to see whether we are allowed to proceed.
1984     for (int i = 0; i < numInputs; ++i) {
1985       OpOperand *consumer = genericOp.getInputOperand(i);
1986       OpResult producer = consumer->get().cast<OpResult>();
1987       if (!controlFn(producer, *consumer))
1988         return failure();
1989     }
1990 
1991     auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
1992     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
1993     int64_t numElements = outputType.getNumElements();
1994 
1995     // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases, but their lifetime is tied to the
    // MLIRContext.
1998     SmallVector<APInt> intOutputValues;
1999     SmallVector<APFloat> fpOutputValues;
2000     if (elementType.template isa<FloatType>())
2001       fpOutputValues.resize(numElements, APFloat(0.f));
2002     else
2003       intOutputValues.resize(numElements);
2004 
2005     // Return the constant dim positions from the given permutation map.
2006     auto getDimPositions = [](AffineMap map) {
2007       SmallVector<unsigned> dims;
2008       dims.reserve(map.getNumResults());
2009       for (AffineExpr result : map.getResults()) {
2010         dims.push_back(result.cast<AffineDimExpr>().getPosition());
2011       }
2012       return dims;
2013     };
2014 
2015     SmallVector<SmallVector<unsigned>> inputDims;
2016     for (int i = 0; i < numInputs; ++i)
2017       inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
2018     auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
2019     auto outputShape = outputType.getShape();
2020 
2021     // Allocate small vectors for index delinearization. Initial values do not
2022     // matter here as they will be overwritten later.
2023     SmallVector<uint64_t> indices(loopBounds.size(), 0);
2024     SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
2025     SmallVector<SmallVector<uint64_t>> srcIndices(
2026         numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
2027     SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
2028     uint64_t dstLinearIndex = 0;
2029 
    // Allocate space for the compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
2032     APIntOrFloatArray computeFnInputs;
2033 
2034     auto inputShapes = llvm::to_vector<4>(
2035         llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
2036           return operand->get().getType().cast<ShapedType>().getShape();
2037         }));
2038 
2039     // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
2041     // `srcLinearIndices`, `dstLinearIndex` in place.
2042     auto computeRemappedLinearIndex = [&](int linearIndex) {
2043       int totalCount = linearIndex;
2044       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
2045         indices[dim] = totalCount % loopBounds[dim];
2046         totalCount /= loopBounds[dim];
2047       }
2048 
2049       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
2050         for (int i = 0; i < numInputs; ++i)
2051           srcIndices[i][dim] = indices[inputDims[i][dim]];
2052         dstIndices[dim] = indices[outputDims[dim]];
2053       }
2054 
2055       dstLinearIndex = dstIndices.front();
2056       for (int i = 0; i < numInputs; ++i)
2057         srcLinearIndices[i] = srcIndices[i].front();
2058 
2059       for (int dim = 1; dim < outputType.getRank(); ++dim) {
2060         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
2061         for (int i = 0; i < numInputs; ++i)
2062           srcLinearIndices[i] =
2063               srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
2064       }
2065     };
2066 
2067     bool isFloat = elementType.isa<FloatType>();
2068     if (isFloat) {
2069       SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
2070       for (int i = 0; i < numInputs; ++i)
2071         inFpRanges.push_back(inputValues[i].getValues<APFloat>());
2072 
2073       computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
2074 
2075       // Transpose the input constant. Because we don't know its rank in
2076       // advance, we need to loop over the range [0, element count) and
2077       // delinearize the index.
2078       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2079         computeRemappedLinearIndex(linearIndex);
2080 
2081         // Collect constant elements for all inputs at this loop iteration.
2082         for (int i = 0; i < numInputs; ++i)
2083           computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
2084 
2085         // Invoke the computation to get the corresponding constant output
2086         // element.
2087         fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
2088       }
2089     } else {
2090       SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
2091       for (int i = 0; i < numInputs; ++i)
2092         inIntRanges.push_back(inputValues[i].getValues<APInt>());
2093 
2094       computeFnInputs.apInts.resize(numInputs);
2095 
2096       // Transpose the input constant. Because we don't know its rank in
2097       // advance, we need to loop over the range [0, element count) and
2098       // delinearize the index.
2099       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2100         computeRemappedLinearIndex(linearIndex);
2101 
2102         // Collect constant elements for all inputs at this loop iteration.
2103         for (int i = 0; i < numInputs; ++i)
2104           computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
2105 
2106         // Invoke the computation to get the corresponding constant output
2107         // element.
2108         intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
2109       }
2110     }
2111 
2112     DenseElementsAttr outputAttr =
2113         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
2114                 : DenseElementsAttr::get(outputType, intOutputValues);
2115 
2116     rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
2117     return success();
2118   }
2119 
2120 private:
2121   ControlElementwiseOpsFusionFn controlFn;
2122 };
2123 
/// Folds linalg.generic ops that are actually transposes of constant values.
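///
/// As an illustrative example (maps chosen arbitrarily): a generic op with
/// indexing maps
///   [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>]
/// whose region only yields its first block argument transposes its constant
/// input, so the whole op can be replaced by the transposed arith.constant.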
2125 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
2126   using FoldConstantBase::FoldConstantBase;
2127 
2128   bool matchIndexingMaps(GenericOp genericOp) const {
2129     // We should have one input and one output.
2130     return genericOp.getIndexingMaps().size() == 2;
2131   }
2132 
2133   RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
2134     // Make sure the region only contains a yield op.
2135     Block &body = genericOp.region().front();
2136     if (!llvm::hasSingleElement(body))
2137       return nullptr;
2138     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
2139     if (!yieldOp)
2140       return nullptr;
2141 
    // The yield op should return the block argument that corresponds to the
    // input.
2143     for (Value yieldVal : yieldOp.values()) {
2144       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
2145       if (!yieldArg || yieldArg.getOwner() != &body)
2146         return nullptr;
2147       if (yieldArg.getArgNumber() != 0)
2148         return nullptr;
2149     }
2150 
    // No computation; just return the original value.
2152     return [](const APIntOrFloatArray &inputs) {
2153       if (inputs.apFloats.empty())
2154         return APIntOrFloat{inputs.apInts.front(), llvm::None};
2155       return APIntOrFloat{llvm::None, inputs.apFloats.front()};
2156     };
2157   }
2158 
2159   ControlElementwiseOpsFusionFn controlFn;
2160 };
2161 
2162 } // namespace
2163 
2164 //===---------------------------------------------------------------------===//
2165 // Miscellaneous patterns that help fusion.
2166 //===---------------------------------------------------------------------===//
2167 
2168 namespace {
2169 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
2171 /// implemented for `linalg.generic` operations for now, but should hold for all
2172 /// linalg structured ops.
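///
/// As an illustrative sketch: if the `outs` operand of a `linalg.generic` is
/// produced by some other computation (e.g. a fill) but the payload never
/// reads the corresponding block argument, the operand is replaced by a
/// fresh `linalg.init_tensor` of the same shape, removing the false
/// dependency and exposing more fusion opportunities.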
2173 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2174   using OpRewritePattern<GenericOp>::OpRewritePattern;
2175 
2176   LogicalResult matchAndRewrite(GenericOp op,
2177                                 PatternRewriter &rewriter) const override {
2178     rewriter.startRootUpdate(op);
2179     bool modifiedOutput = false;
2180     Location loc = op.getLoc();
2181     for (OpOperand *opOperand : op.getOutputOperands()) {
2182       if (!op.payloadUsesValueFromOperand(opOperand)) {
2183         Value operandVal = opOperand->get();
2184         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
2185         if (!operandType)
2186           continue;
2187 
2188         // If outs is sparse, leave it to the sparse compiler.
2189         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2190           continue;
2191 
2192         // If outs is already an `init_tensor` operation, nothing to do.
2193         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
2194         if (definingOp)
2195           continue;
2196         modifiedOutput = true;
2197         SmallVector<Value> dynamicDims;
2198         for (const auto &dim : llvm::enumerate(operandType.getShape())) {
2199           if (dim.value() != ShapedType::kDynamicSize)
2200             continue;
2201           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
2202               loc, operandVal, dim.index()));
2203         }
2204         Value initTensor = rewriter.create<InitTensorOp>(
2205             loc, dynamicDims, operandType.getShape(),
2206             operandType.getElementType());
2207         op->setOperand(opOperand->getOperandNumber(), initTensor);
2208       }
2209     }
2210     if (!modifiedOutput) {
2211       rewriter.cancelRootUpdate(op);
2212       return failure();
2213     }
2214     rewriter.finalizeRootUpdate(op);
2215     return success();
2216   }
2217 };
2218 } // namespace
2219 
2220 //===---------------------------------------------------------------------===//
2221 // Methods that add patterns described in this file to a pattern list.
2222 //===---------------------------------------------------------------------===//
2223 
2224 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
2225     RewritePatternSet &patterns) {
2226   patterns.add<
2227       FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
2228       FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
2229       FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
2230       patterns.getContext());
2231 }
2232 
2233 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
2234     RewritePatternSet &patterns) {
2235   patterns
2236       .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
2237            FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
2238            FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
2239           patterns.getContext());
2240 }
2241 
2242 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2243     RewritePatternSet &patterns,
2244     const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
2245   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2246                                                     controlFoldingReshapes);
2247   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2248                                                      controlFoldingReshapes);
2249 }
2250 
2251 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2252     RewritePatternSet &patterns,
2253     const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
2254   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2255                                                       controlFoldingReshapes);
2256 }
2257 
2258 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2259     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
2260   auto *context = patterns.getContext();
2261   patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
2262                FoldConstantTranspose>(context,
2263                                       options.controlElementwiseOpsFusionFn);
2264   patterns.add<RemoveOutsDependency>(context);
2265   populateSparseTensorRewriting(patterns);
2266   populateFoldReshapeOpsByExpansionPatterns(patterns,
2267                                             options.controlFoldingReshapesFn);
2268   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2269   GenericOp::getCanonicalizationPatterns(patterns, context);
2270   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2271   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2272   context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2273       patterns);
2274 }
2275 
2276 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
2277   auto *context = patterns.getContext();
2278   patterns.add<PushExpandingReshape>(context);
2279 }
2280 
2281 //===---------------------------------------------------------------------===//
2282 // Passes
2283 //===---------------------------------------------------------------------===//
2284 
2285 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
2286                                       OpOperand &consumer) {
2287   if (auto producerCollapseOp =
2288           dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
2289     return !isUnitDimExpansionOnly(producerCollapseOp);
2290   }
2291   if (auto consumerExpandOp =
2292           dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
2293     return !isUnitDimExpansionOnly(consumerExpandOp);
2294   }
2295   return true;
2296 }
2297 
2298 namespace {
2299 
2300 /// Pass that fuses generic ops on tensors. Used only for testing.
2301 struct LinalgElementwiseOpFusionPass
2302     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
2303   void runOnOperation() override {
2304     Operation *op = getOperation();
2305     RewritePatternSet patterns(op->getContext());
2306     ControlElementwiseOpsFusionFn allowFoldingFn =
2307         [](const OpResult &producer, const OpOperand &consumer) {
2308           return true;
2309         };
2310     populateElementwiseOpsFusionPatterns(
2311         patterns,
2312         LinalgElementwiseFusionOptions().setControlFoldingReshapes(
2313             allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
2314 
    // Use top-down traversal for compile-time reasons.
2316     GreedyRewriteConfig grc;
2317     grc.useTopDownTraversal = true;
2318     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
2319                                        grc);
2320   }
2321 };
2322 
2323 /// Pass to test folding of reshape ops with generic ops by linearization.
2324 struct FoldReshapeOpsByLinearizationPass
2325     : public LinalgFoldReshapeOpsByLinearizationBase<
2326           FoldReshapeOpsByLinearizationPass> {
2327   void runOnOperation() override {
2328     Operation *op = getOperation();
2329     RewritePatternSet patterns(op->getContext());
2330     populateFoldReshapeOpsByLinearizationPatterns(patterns);
2331     if (allowFoldingUnitDimReshapes) {
2332       populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
2333     }
2334     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
2335   }
2336 };
2337 
2338 } // namespace
2339 
2340 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
2341   return std::make_unique<LinalgElementwiseOpFusionPass>();
2342 }
2343 
2344 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
2345   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
2346 }
2347