1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 /// Implementation of fusion of generic ops and indexed_generic ops.
29 static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
30                                 unsigned consumerIdx) {
31   // Producer and consumer must have tensor semantics.
32   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
33     return false;
34 
35   // Verify that
36   // - the producer has all "parallel" iterator type.
37   if (producer.getNumParallelLoops() != producer.getNumLoops())
38     return false;
39 
40   // Get the consumer index map. The number of results of the consumer index
41   // map must match the number of loops of the producer.
42   AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
43   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
44     return false;
45 
46   // Finally the index_map for the result must be invertible. For now just
47   // verify it is a permutation.
48   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
49   return producerResultIndexMap.isPermutation();
50 }
51 
52 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
53 /// the `producer` to use in the fused operation given the indexing map of the
54 /// result of the producer in the consumer.
55 static void getIndexingMapOfProducerOperandsInFusedOp(
56     LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
57     SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
58   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
59   // from consumer loop -> consumer arg tensor index/producer result tensor
60   // index. The fused loop is same as the consumer loop. For each producer arg
61   // the indexing map to be computed is a map from consumer loop -> producer
62   // arg tensor index.
63 
64   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
65   // producerResultIndexMap is a map from producer loop -> tensor index.
66   // Compute the inverse to get map from tensor index -> producer loop.
67   // The inverse is a map from producer result tensor index -> producer loop.
68   AffineMap invProducerResultIndexMap =
69       inversePermutation(producerResultIndexMap);
70   assert(invProducerResultIndexMap &&
71          "expected producer result indexig map to be invertible");
72   for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
73     // argMap is a map from producer loop -> producer arg tensor index.
74     AffineMap argMap = producer.getInputIndexingMap(argNum);
75 
76     // Compose argMap with invProducerResultIndexMap to get a map from
77     // producer result tensor index -> producer arg tensor index.
78     AffineMap t1 = argMap.compose(invProducerResultIndexMap);
79 
80     // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
81     // consumer loop/ fused loop -> producer arg tensor index.
82     AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
83     fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
84   }
85 }
86 
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
/// NOTE(review): the `nloops` parameter is not referenced in this
/// implementation — confirm whether it can be dropped at the call site.
static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
                                        Operation *fusedOp, LinalgOp producer,
                                        LinalgOp consumer,
                                        AffineMap consumerToProducerLoopsMap,
                                        unsigned consumerIdx, unsigned nloops) {
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  // Create the (single) entry block of the fused op's region.
  Block *fusedBlock = new Block();
  fusedOp->getRegion(0).push_back(fusedBlock);
  // Maps values from the producer/consumer blocks to their clones in the
  // fused block.
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // The block arguments are
  // [index_0, index_1, ... ,
  //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
  //   producer_operand_0, ... , producer_operand_(n-1)],
  //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
  // , where n is the number of producer's operand and m is the number
  // consumer's operand.
  // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
  // generic op. In this case, there are no indices in block arguments.
  unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
                                    ? producer.getNumLoops()
                                    : 0;
  unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
                                    ? consumer.getNumLoops()
                                    : 0;
  // If either op is indexed_generic, the fused op carries the maximum number
  // of index arguments so both bodies can be rewritten in terms of them.
  unsigned numFusedOpIndices =
      (isa<IndexedGenericOp>(producer.getOperation()) ||
       isa<IndexedGenericOp>(consumer.getOperation()))
          ? std::max(producer.getNumLoops(), consumer.getNumLoops())
          : 0;
  // Firstly, add all the indices to the block arguments.
  for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
    fusedBlock->addArgument(rewriter.getIndexType());
  // Map the arguments for the unmodified args from the consumer.
  for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
    // The fused operand's position is where the producer's arguments are
    // spliced in.
    if (consumerArg.index() == consumerIdx + numConsumerIndices) {
      // Map the arguments for the args from the producer.
      for (auto producerArg :
           llvm::enumerate(producerBlock.getArguments().take_front(
               producer.getNumInputs() + numProducerIndices))) {
        // If producer is an indexed_generic op, map the indices from consumer
        // loop to producer loop (because the fusedOp is built based on
        // consumer's perspective).
        if (producerArg.index() < numProducerIndices) {
          auto newIndex = rewriter.create<mlir::AffineApplyOp>(
              producer.getLoc(),
              consumerToProducerLoopsMap.getSubMap(producerArg.index()),
              fusedBlock->getArguments().take_front(numFusedOpIndices));
          mapper.map(producerArg.value(), newIndex);
        } else {
          // Non-index producer argument: forward it as a fresh fused-block
          // argument of the same type.
          mapper.map(producerArg.value(),
                     fusedBlock->addArgument(producerArg.value().getType()));
        }
      }
      continue;
    }

    // If consumer is an indexed_generic op, map the indices to the block
    // arguments directly. Otherwise, add the same type of argument and map to
    // it.
    if (consumerArg.index() < numConsumerIndices) {
      mapper.map(consumerArg.value(),
                 fusedBlock->getArgument(consumerArg.index()));
    } else {
      mapper.map(consumerArg.value(),
                 fusedBlock->addArgument(consumerArg.value().getType()));
    }
  }

  // Add operations from producer (except the yield operation) to the fused
  // op.
  for (auto &op : producerBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to.
      Value yieldVal = yieldOp.getOperand(0);
      // The consumer's fused-operand block argument now stands for the
      // producer's (cloned) yielded value.
      if (Value clonedVal = mapper.lookupOrNull(yieldVal))
        mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
                   clonedVal);
      continue;
    }
    rewriter.clone(op, mapper);
  }
  // Clone the consumer's body (including its yield) after the producer's
  // cloned body.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);
}
178 
179 static Optional<SmallVector<Value, 1>>
180 fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
181                   PatternRewriter &rewriter) {
182   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
183   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
184   if (!areTensorOpsFusable(producer, consumer, consumerIdx))
185     return llvm::None;
186 
187   unsigned numFusedOperands =
188       producer.getNumInputs() + consumer.getNumInputs() - 1;
189 
190   // Compute the fused operands list,
191   SmallVector<Value, 2> fusedOperands;
192   fusedOperands.reserve(numFusedOperands);
193   auto consumerOperands = consumer.getInputs();
194   auto producerOperands = producer.getInputs();
195   fusedOperands.assign(consumerOperands.begin(),
196                        std::next(consumerOperands.begin(), consumerIdx));
197   fusedOperands.append(producerOperands.begin(), producerOperands.end());
198   fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
199                        consumerOperands.end());
200 
201   // Compute indexing_maps for the fused operation. The indexing_maps for the
202   // operands of the consumers that aren't fused are the same. The
203   // indexing_maps for the producers need to be computed based on the
204   // indexing_map of the operand at consumerIdx in the consumer.
205   SmallVector<Attribute, 4> fusedIndexMaps;
206   auto consumerIndexMaps = consumer.indexing_maps();
207   fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
208   fusedIndexMaps.assign(consumerIndexMaps.begin(),
209                         std::next(consumerIndexMaps.begin(), consumerIdx));
210   // Compute indexing maps for the producer args in the fused operation.
211   getIndexingMapOfProducerOperandsInFusedOp(
212       producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
213 
214   // Append the indexing maps for the remaining consumer operands.
215   fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
216                         consumerIndexMaps.end());
217 
218   // Generate the fused op.
219   LinalgOp fusedOp;
220   if (isa<GenericOp>(producer.getOperation()) &&
221       isa<GenericOp>(consumer.getOperation())) {
222     fusedOp =
223         rewriter
224             .create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
225                                /*inputs=*/fusedOperands,
226                                // TODO: handle outputs.
227                                consumer.getOutputs(),
228                                rewriter.getArrayAttr(fusedIndexMaps),
229                                consumer.iterator_types(),
230                                /*doc=*/nullptr,
231                                /*library_call=*/nullptr,
232                                /*sparse=*/nullptr)
233             .getOperation();
234   } else {
235     fusedOp =
236         rewriter
237             .create<IndexedGenericOp>(
238                 consumer.getLoc(), consumer->getResultTypes(),
239                 /*inputs=*/fusedOperands,
240                 // TODO: handle outputs.
241                 consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps),
242                 consumer.iterator_types(),
243                 /*doc=*/nullptr,
244                 /*library_call=*/nullptr,
245                 /*sparse=*/nullptr)
246             .getOperation();
247   }
248 
249   // Construct an AffineMap from consumer loops to producer loops.
250   // consumer loop -> tensor index
251   AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
252   // producer loop -> tensor index
253   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
254   // tensor index -> producer loop
255   AffineMap invProducerResultIndexMap =
256       inversePermutation(producerResultIndexMap);
257   assert(invProducerResultIndexMap &&
258          "expected producer result indexig map to be invertible");
259   // consumer loop -> producer loop
260   AffineMap consumerToProducerLoopsMap =
261       invProducerResultIndexMap.compose(consumerResultIndexMap);
262 
263   generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
264                               consumer, consumerToProducerLoopsMap, consumerIdx,
265                               consumer.getNumLoops());
266   return SmallVector<Value, 1>(fusedOp->getResults());
267 }
268 
269 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
270 /// provided, given the shape of the source tensor that corresponds to the
271 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
272 /// are "row-major" ordered logically.
273 ///
274 /// For example:
275 ///
276 /// %0 = op ... : tensor<?x?x4x5xf32>
277 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
278 ///
279 /// and reshape:
280 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
281 ///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
282 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
283 ///
284 /// would be rewritten into:
285 /// %0 = op ... : tensor<?x?x4x5xf32>
286 /// with output index_map
287 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
288 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
289                                         ArrayRef<int64_t> sourceShape,
290                                         ArrayRef<AffineMap> reassociationMaps) {
291   SmallVector<AffineExpr, 4> resultExprs;
292   resultExprs.reserve(reassociationMaps.size());
293   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
294   MLIRContext *context = sourceMap.getContext();
295 
296   // Compute the result exprs based on the reassociation maps.
297   for (AffineMap map : reassociationMaps) {
298     ArrayRef<AffineExpr> collapsedDims = map.getResults();
299     // Assume that they are in-order and contiguous (already checked in
300     // verifier).
301     assert(!collapsedDims.empty());
302     unsigned startDim =
303         collapsedDims.front().cast<AffineDimExpr>().getPosition();
304     SmallVector<int64_t, 4> sizes;
305     SmallVector<AffineExpr, 4> dimExprs;
306     for (auto en :
307          llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
308                    sourceExprs.slice(startDim, collapsedDims.size()))) {
309       if (std::get<0>(en) == 1)
310         continue;
311       sizes.push_back(std::get<0>(en));
312       dimExprs.push_back(std::get<1>(en));
313     }
314     AffineExpr linearizedExpr =
315         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
316     resultExprs.push_back(linearizedExpr);
317   }
318   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
319                         resultExprs, context);
320 }
321 
322 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
323 /// true) or its producer (if `asProducer` is false) given the indexing map at
324 /// its use.
325 static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
326                                                      AffineMap useIndexMap,
327                                                      bool asProducer) {
328   RankedTensorType returnType = reshapeOp.getResultType();
329   RankedTensorType operandType = reshapeOp.getSrcType();
330   // Reshape is fusable with its consumer (i.e. reshape as a producer) when its
331   // operand is of lesser rank than the result. Fusing when operand has higher
332   // rank will require use of mods and divs in the indexing maps of the fused op
333   // which would make it non-invertible. Similarly reshape is fused with its
334   // producer (i.e. reshape as consumer) only if the return type has lesser
335   // rank.
336   if ((asProducer && reshapeOp.getSrcType().hasStaticShape() &&
337        returnType.getRank() < operandType.getRank()) ||
338       (!asProducer && reshapeOp.getResultType().hasStaticShape() &&
339        operandType.getRank() < returnType.getRank()))
340     return false;
341   return useIndexMap.isPermutation();
342 }
343 
/// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
/// is a linalg.generic operation, then create a `linalg.generic` operation
/// with the given `args`. Expects `op` to be `linalg.generic` or
/// `linalg.indexed_generic`.
template <typename... Args>
static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
                                         Args... args) {
  if (isa<GenericOp>(op.getOperation()))
    return rewriter.create<GenericOp>(args...);
  if (isa<IndexedGenericOp>(op.getOperation()))
    return rewriter.create<IndexedGenericOp>(args...);
  llvm_unreachable(
      "expected only linalg.generic or linalg.indexed_generic ops");
  // Unreachable; kept to silence "missing return" warnings on compilers that
  // do not treat llvm_unreachable as noreturn.
  return nullptr;
}
359 
360 /// Check if the reshape operation is only expansion into/collapsing of
361 /// unit-dimension.
362 static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
363                                    ArrayRef<AffineMap> reassociation) {
364   for (auto &map : reassociation) {
365     unsigned numUnitDims = 0;
366     for (AffineExpr expr : map.getResults()) {
367       unsigned position = expr.cast<AffineDimExpr>().getPosition();
368       if (expandedShape[position] == 1)
369         numUnitDims++;
370     }
371     if (numUnitDims != map.getNumResults() - 1)
372       return false;
373   }
374   return true;
375 }
376 
377 /// Conditions for folding a generic/indexed-generic operation with a reshape op
378 /// by expanding the iteration space dimensionality for tensor operations. These
379 /// are preconditions assumed by `foldReshapeByDimExpansion` which implements
380 /// the following fusion pattern.
381 ///
382 ///  Consider
383 ///
384 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
385 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
386 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
387 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
388 ///  %d = linalg.tensor_reshape %c
389 ///         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
390 ///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
391 ///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
392 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
393 ///
394 ///  The reshape can be folded into the `linalgOp` if the
395 ///  generic/indexed-generic op loop dimensionality is increased to match the
396 ///  result (operand) of the tensor_reshape when the reshape is expanding
397 ///  (folding). The indexing_map of the fused tensor in the `linalgOp` and the
398 ///  reassociation map helps compute the indexing maps of the modified op. For
399 ///  the above example, based on the reassociation map it can be concluded that
400 ///
401 ///  - The loop used to access the first dimension of the fused tensor is split
402 ///    into two.
403 ///  - The loop used to access the second dimension of the fused tensor is kept
404 ///    as is.
405 ///  - The loop used to access the third dimension of the fused tensor is split
406 ///    into three.
407 ///
408 ///  i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
409 ///  op, then
410 ///
411 ///   d0 -> e0, e1
412 ///   d1 -> e2, e3, e4
413 ///   d2 -> e5
414 ///
415 ///  substituting this, the generic op can be rewritten as
416 ///
417 ///  %d = linalg.generic ins(%0, %1 : )
418 ///        indexing_maps =
419 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
420 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
421 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
422 ///
423 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
424 ///  to make it consistent
425 ///
426 ///  %0 = linalg.tensor_reshape %a
427 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2),
428 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4),
429 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)]
430 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
431 ///  %1 = linalg.tensor_reshape %b
432 ///         [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2),
433 ///          affine_map<(e0, e1, e2, e3) -> (e3)]
434 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
435 ///
436 ///  The added reshapes are again expanding patterns, so they will get fused
437 ///  with its producers if possible.
438 static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
439                                                unsigned fusedTensorIndex) {
440   // Is fusable only if:
441   // - The linalgOp is a generic op, or an indexed_generic.
442   // - All the indexing maps for operands and results in linalgOp are projected
443   //   permutations.
444   // - The fused tensor is not a scalar.
445   // - All the loops in linalgOp are parallel loops.
446   return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
447          linalgOp.hasTensorSemantics() &&
448          llvm::all_of(linalgOp.indexing_maps().getValue(),
449                       [](Attribute attr) {
450                         return attr.cast<AffineMapAttr>()
451                             .getValue()
452                             .isProjectedPermutation();
453                       }) &&
454          linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
455          llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
456            return attr.cast<StringAttr>().getValue() ==
457                   getParallelIteratorTypeName();
458          });
459 }
460 
namespace {
/// Information needed to expand a generic/indexed_generic operation to fold the
/// reshape with it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic/indexed_generic op, the `reassociationMaps` of the reshape op
  // and the shape of the expanded op.
  LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape);
  // Number of dimensions (loops) in the original operation.
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  // Number of dimensions (loops) in the expanded operation.
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  // Returns the expanded-op dimensions that original dimension `i` maps to.
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  // Returns the extents of the expanded dimensions that original dimension
  // `i` maps to.
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices, 4> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t, 4>, 4> expandedShapeMap;
  /// Total number of dimensions in the expanded operation.
  unsigned expandedOpNumDims;
};
} // namespace
492 
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     unsigned fusedTensorIndex,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape) {
  // An empty reassociation list means there is nothing to expand.
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);

  // Static loop ranges are needed to supply extents for the dimensions not
  // touched by the reshape (see the "remaining dimensions" loop below).
  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      getStaticLoopRanges(linalgOp);
  if (!originalLoopRange)
    return linalgOp.emitError("unable to find loop range for operation");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimension in the expanded op that correspond to each
  // dimension of the original op.
  SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    // Slice out of `expandedShape` the extents of the dims that original
    // dimension `pos` expands into.
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
    // Original dim i maps to the contiguous run [sum, sum + count) of
    // expanded dims.
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
536 
/// To expand an indexed_generic operation, the body of the indexed generic op
/// needs to be modified appropriately. Specifically, uses of arguments for
/// induction variables in the original operation need to be replaced with
/// linearizations of the corresponding arguments in the expanded op. That
/// requires the shape of the expanded dimensions (at least all but the most
/// significant). For now check that these are all statically sized. Note that
/// this could be extended to handle the dynamic case, but the implementation
/// below uses `affine.apply`, which seems to have issues when the shapes are
/// not static.
546 LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp,
547                                            const ExpansionInfo &expansionInfo) {
548   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
549     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
550     if (expandedShape.size() == 1)
551       continue;
552     for (int64_t shape : expandedShape.drop_front()) {
553       if (ShapedType::isDynamic(shape)) {
554         return linalgOp.emitError(
555             "unable to fuse indexed generic op where the expanded dim is "
556             "dynamic");
557       }
558     }
559   }
560   return success();
561 }
562 
563 /// Return the indexing map to use in the expanded op for a given the
564 /// `indexingMap` of the original operation.
565 static AffineMap
566 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
567                            const ExpansionInfo &expansionInfo) {
568   SmallVector<AffineExpr, 4> newExprs;
569   for (AffineExpr expr : indexingMap.getResults()) {
570     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
571     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
572         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
573           return builder.getAffineDimExpr(static_cast<unsigned>(v));
574         }));
575     newExprs.append(expandedExprs.begin(), expandedExprs.end());
576   }
577   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
578                         indexingMap.getNumSymbols(), newExprs,
579                         builder.getContext());
580 }
581 
582 /// Return the type of the operand/result to use in the expanded op given the
583 /// type in the original op.
584 static RankedTensorType getExpandedType(RankedTensorType originalType,
585                                         AffineMap indexingMap,
586                                         const ExpansionInfo &expansionInfo) {
587   SmallVector<int64_t, 4> expandedShape;
588   for (AffineExpr expr : indexingMap.getResults()) {
589     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
590     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
591     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
592   }
593   return RankedTensorType::get(expandedShape, originalType.getElementType());
594 }
595 
596 /// Returns the reassociation maps to use in the `linalg.tensor_reshape`
597 /// operation to convert the operands of the origial operation to operands of
598 /// the expanded operation. The same method is used to compute the
599 /// `linalg.tensor_reshape` used to collapse the result of the expanded op to
600 /// get the value that can replace all uses of the results of the original op.
601 static SmallVector<ReassociationIndices, 4>
602 getReassociationForExpansion(AffineMap indexingMap,
603                              const ExpansionInfo &expansionInfo) {
604   SmallVector<ReassociationIndices, 4> reassociation;
605   unsigned numReshapeDims = 0;
606   for (AffineExpr expr : indexingMap.getResults()) {
607     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
608     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
609     auto indices = llvm::to_vector<2>(
610         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
611     reassociation.emplace_back(std::move(indices));
612     numReshapeDims += numExpandedDims;
613   }
614   return reassociation;
615 }
616 
/// Build the body of the expanded IndexedGenericOp. The arguments for the
/// induction variables of the original operation need to be recovered by
/// linearizing the arguments of the corresponding dimensions of the expanded
/// op. For now it is assumed that the shapes of the expanded op needed for
/// linearization are static.
static void buildExpandedIndexedGenericOpRegion(
    PatternRewriter &rewriter, Location loc, Region &originalOpRegion,
    Region &fusedOpRegion, const ExpansionInfo &expansionInfo) {
  assert(fusedOpRegion.empty() && "expected fused op to have empty region");
  // Create an entry block in the fused region with same number of arguments
  // as the fused op
  Block *fusedEntryBlock = new Block;
  fusedOpRegion.push_back(fusedEntryBlock);
  // The original body is cloned after the new entry block and merged into it
  // below.
  rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion,
                             fusedOpRegion.end());

  // Merge the entry block of the fused op with the cloned blocks. For this
  // compute the value for arguments of the region in the original operation
  // in terms of the arguments of the fused op. Since the original operation
  // is expanded, the expanded dimensions need to be folded back to get the
  // replacement value for the arguments corresponding to iteration index.
  // For now this expects that all the loop ranges are constants, which is
  // true if the shapes are all static. This has already been checked in the
  // precondition.
  using namespace edsc::op;
  using namespace edsc::intrinsics;
  OpBuilder::InsertionGuard guard(rewriter);
  SmallVector<Value, 4> argReplacements(originalOpRegion.getNumArguments());
  rewriter.setInsertionPointToStart(fusedEntryBlock);
  edsc::ScopedContext scopedContext(rewriter, loc);
  IndexType indexType = rewriter.getIndexType();
  for (auto i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    // Linearize the expanded index arguments back into the original
    // induction variable: idx = (...((a0 * s1 + a1) * s2 + a2)...), using the
    // extents of all but the most significant expanded dimension.
    Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(i).drop_front();
    // NOTE(review): `shape` narrows int64_t to unsigned; the extents are
    // static (asserted below), so this is presumed in range — confirm.
    for (unsigned shape : expandedDimsShape) {
      assert(!ShapedType::isDynamic(shape));
      linearizedIndex = linearizedIndex * std_constant_index(shape);
      linearizedIndex =
          linearizedIndex + fusedEntryBlock->addArgument(indexType);
    }
    argReplacements[i] = linearizedIndex;
  }
  // Non-index region arguments keep their original types.
  for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(),
                                    argReplacements.size())) {
    argReplacements[i] =
        fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType());
  }
  // Merge the cloned entry block into the new entry block, substituting the
  // computed replacements for the original block arguments.
  rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
                       argReplacements);
}
668 
669 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
670 /// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
671 /// conditions have been satisfied.
672 static Optional<SmallVector<Value, 1>>
673 fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
674                            unsigned fusedTensorIndex,
675                            PatternRewriter &rewriter) {
676   assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
677          "preconditions for fuse operation failed");
678   // Check if reshape is expanding or collapsing.
679   bool isExpanding =
680       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
681   RankedTensorType expandedType =
682       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
683 
684   ExpansionInfo expansionInfo;
685   if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
686                                    reshapeOp.getReassociationMaps(),
687                                    expandedType.getShape())))
688     return llvm::None;
689 
690   if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
691       failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
692     return llvm::None;
693 
694   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
695       llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
696         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
697       }));
698 
699   SmallVector<Value, 4> expandedOpOperands;
700   for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
701     if (operand.index() == fusedTensorIndex) {
702       expandedOpOperands.push_back(reshapeOp.src());
703       continue;
704     }
705     AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
706     RankedTensorType expandedOperandType =
707         getExpandedType(operand.value().getType().cast<RankedTensorType>(),
708                         indexingMap, expansionInfo);
709     if (expandedOperandType != operand.value().getType()) {
710       // Reshape the operand to get the right type.
711       SmallVector<ReassociationIndices, 4> reassociation =
712           getReassociationForExpansion(indexingMap, expansionInfo);
713       expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
714           linalgOp.getLoc(), expandedOperandType, operand.value(),
715           reassociation));
716       continue;
717     }
718     expandedOpOperands.push_back(operand.value());
719   }
720 
721   Location loc = linalgOp.getLoc();
722   SmallVector<Value, 1> outputs;
723   for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
724     AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
725     RankedTensorType expandedOutputType =
726         getExpandedType(result.value().getType().cast<RankedTensorType>(),
727                         indexingMap, expansionInfo);
728     if (expandedOutputType != result.value().getType()) {
729       SmallVector<ReassociationIndices, 4> reassociation =
730           getReassociationForExpansion(indexingMap, expansionInfo);
731       outputs.push_back(rewriter.create<TensorReshapeOp>(
732           linalgOp.getLoc(), expandedOutputType, result.value(),
733           reassociation));
734     }
735   }
736 
737   // The iterator types of the expanded op are all parallel.
738   SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
739                                           getParallelIteratorTypeName());
740 
741   TypeRange resultTypes = ValueRange(outputs).getTypes();
742   LinalgOp fusedOp = createLinalgOpOfSameType(
743       linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
744       /*inputs=*/expandedOpOperands, outputs, expandedOpIndexingMaps,
745       iteratorTypes);
746   Region &fusedRegion = fusedOp->getRegion(0);
747   Region &originalRegion = linalgOp->getRegion(0);
748 
749   if (isa<GenericOp>(linalgOp.getOperation())) {
750     rewriter.cloneRegionBefore(originalRegion, fusedRegion,
751                                fusedRegion.begin());
752   } else {
753     assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
754     buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
755                                         fusedRegion, expansionInfo);
756   }
757 
758   // Reshape the result values to their original shape if this is a collapsing
759   // reshape folded into its consumer.
760   SmallVector<Value, 1> resultVals;
761   for (auto result : llvm::enumerate(linalgOp->getResults())) {
762     if (!isExpanding &&
763         resultTypes[result.index()] != result.value().getType()) {
764       SmallVector<ReassociationIndices, 4> reassociation =
765           getReassociationForExpansion(
766               linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
767       resultVals.push_back(rewriter.create<TensorReshapeOp>(
768           linalgOp.getLoc(), result.value().getType(),
769           fusedOp->getResult(result.index()), reassociation));
770     } else {
771       resultVals.push_back(fusedOp->getResult(result.index()));
772     }
773   }
774   // Assuming a single result.
775   return resultVals;
776 }
777 
778 namespace {
779 
/// Pattern to fold tensor_reshape op with its consumer by using the source of
/// the reshape op as the operand in the consumer (instead of the result of the
/// tensor_reshape op) when the tensor_reshape op is collapsing. The
/// corresponding index map in the consumer needs to be modified to linearize
/// the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.tensor_reshape %arg0
///        [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>,
///         affine_map<(i, j, k, l) -> (l)>]
///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
///
/// When `foldUnitDimReshapesOnly` is set, only reshapes whose reassociation
/// merely adds/removes unit dimensions are folded.
template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<LinalgOpTy> {
  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LinalgOpTy op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics())
      return failure();
    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
    // Fold at most one operand per invocation: the first input defined by a
    // foldable tensor_reshape.
    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
      TensorReshapeOp reshapeOp =
          operand.value().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp ||
          !isTensorReshapeOpFoldableByLinearization(
              reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
              /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly &&
           !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
                                   reshapeOp.getReassociationMaps())))
        continue;

      // Compute the fused operands list: inputs with the reshape result
      // replaced by the reshape source, followed by the outputs.
      SmallVector<Value, 2> fusedOperands(linalgOp.getInputs());
      fusedOperands[operand.index()] = reshapeOp.src();
      fusedOperands.append(linalgOp.getOutputs().begin(),
                           linalgOp.getOutputs().end());

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumers that aren't fused are the same.
      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
          op.indexing_maps().template getAsValueRange<AffineMapAttr>());

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap =
          linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(),
                                 reshapeOp.getReassociationMaps());
      // Only pure affine results are supported; bail out otherwise.
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[operand.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
        return op.emitRemark("fused op loop bound computation failed");

      // Update the op in place: swap in the fused operands and the modified
      // indexing maps.
      rewriter.startRootUpdate(op);
      op->setOperands(fusedOperands);
      op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(op);
      // Erase the reshape if this update consumed its only use.
      if (reshapeOp.use_empty())
        rewriter.eraseOp(reshapeOp);
      return success();
    }
    return failure();
  }
};
866 
867 /// Pattern to fuse a tensor_reshape op with its consumer
868 /// generic/indexed_generic op, when the reshape op is collapsing
869 /// dimensions. The dimensionality of the loop in the consumer is expanded.
870 template <typename GenericOpTy>
871 struct FoldWithProducerReshapeOpByExpansion
872     : public OpRewritePattern<GenericOpTy> {
873   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
874 
875   LogicalResult matchAndRewrite(GenericOpTy genericOp,
876                                 PatternRewriter &rewriter) const override {
877     LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
878     for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
879       TensorReshapeOp reshapeOp =
880           operand.value().getDefiningOp<TensorReshapeOp>();
881       if (!reshapeOp)
882         continue;
883 
884       // Fold only if
885       // - The tensor reshape op is folding.
886       // - All constraints of fusing with reshape by expansion are met.
887       if (reshapeOp.getSrcType().getRank() <
888               reshapeOp.getResultType().getRank() ||
889           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
890           isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
891                                  reshapeOp.getReassociationMaps()))
892         continue;
893 
894       Optional<SmallVector<Value, 1>> replacementValues =
895           fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
896                                      rewriter);
897       if (!replacementValues)
898         return failure();
899       rewriter.replaceOp(genericOp, replacementValues.getValue());
900       if (reshapeOp.use_empty())
901         rewriter.eraseOp(reshapeOp);
902       return success();
903     }
904     return failure();
905   }
906 };
907 
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
/// map in the consumer needs to be modified to linearize the folded dimension.
/// When `foldUnitDimReshapesOnly` is set, only reshapes whose reassociation
/// merely adds/removes unit dimensions are folded.
template <bool foldUnitDimReshapesOnly>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Match only when the reshape source is produced by a single-output
    // generic/indexed_generic op on tensors whose output indexing map makes
    // the linearization valid.
    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
    if (!producer ||
        !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
        !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp, producer.getOutputIndexingMap(0),
            /*asProducer =*/false) ||
        (foldUnitDimReshapesOnly &&
         !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
                                 reshapeOp.getReassociationMaps())))
      return failure();
    // The indexing_maps for the operands of the fused operation are same as
    // those for the operands of the producer.
    SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
        producer.indexing_maps().getAsValueRange<AffineMapAttr>());

    // Invert the producer's output map so the linearization can be expressed
    // in the producer's iteration space (requires the map to be a
    // permutation).
    auto invMap = inversePermutation(producer.getOutputIndexingMap(0));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap =
        linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
                               reshapeOp.getReassociationMaps());
    // Every result expression of the modified map must be pure affine.
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return producer.emitRemark("fused op indexing map is not affine");
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return reshapeOp.emitRemark("fused op loop bound computation failed");

    // Create the fused op: same inputs, the output reshaped through the
    // reshape's reassociation, and the updated indexing maps.
    Location loc = producer.getLoc();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
    LinalgOp fusedOp = createLinalgOpOfSameType(
        producer, rewriter, loc, reshapeOp.getResultType(),
        /*inputs=*/producer.getInputs(),
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*sparse=*/nullptr);
    // The body of the producer is reused verbatim.
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    // Erase the producer if the fused op replaced its only use.
    if (producer.use_empty())
      rewriter.eraseOp(producer);
    return success();
  }
};
971 
972 /// Pattern to fold a tensor_reshape op with its producer generic op if the
973 /// tensor_reshape op is expanding, by expanding the dimensionality of the loop
974 /// in the producer op.
975 struct FoldReshapeWithGenericOpByExpansion
976     : public OpRewritePattern<TensorReshapeOp> {
977   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
978   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
979                                 PatternRewriter &rewriter) const override {
980     // Fold only if
981     // - The tensor reshape op is a expanding case.
982     // - All constraints of fusing with reshape by expansion are met.
983     if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
984       return failure();
985     LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
986     if (!producer || producer.getNumOutputs() != 1 ||
987         !isFusableWithReshapeByDimExpansion(producer,
988                                             producer.getNumInputs()) ||
989         isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
990                                reshapeOp.getReassociationMaps()))
991       return failure();
992     Optional<SmallVector<Value, 1>> replacementValues =
993         fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
994                                    rewriter);
995     if (!replacementValues)
996       return failure();
997     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
998     if (producer.use_empty())
999       rewriter.eraseOp(producer);
1000     return success();
1001   }
1002 };
1003 
1004 /// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
1005 template <typename LinalgOpTy>
1006 struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
1007   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1008 
1009   LogicalResult matchAndRewrite(LinalgOpTy op,
1010                                 PatternRewriter &rewriter) const override {
1011     if (!op.hasTensorSemantics())
1012       return failure();
1013     LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
1014     for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
1015       ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
1016       if (!constantOp ||
1017           !constantOp.value().cast<DenseElementsAttr>().isSplat())
1018         continue;
1019 
1020       // The indexing_maps for the operands of the fused operation are same as
1021       // those for the operands of the linalgOp without the indexing map at
1022       // operand.index()
1023       SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
1024           linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
1025       fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
1026 
1027       // The operands list is same as the linalgOp with the argument for
1028       // constant index dropped.
1029       SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
1030       fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
1031 
1032       // Create a constant scalar value from the splat constant.
1033       Value scalarConstant = rewriter.create<ConstantOp>(
1034           constantOp.getLoc(),
1035           constantOp.value().cast<DenseElementsAttr>().getSplatValue());
1036 
1037       LinalgOp fusedOp = createLinalgOpOfSameType(
1038           linalgOp, rewriter, rewriter.getUnknownLoc(),
1039           linalgOp->getResultTypes(),
1040           /*inputs=*/fusedOperands,
1041           /*outputs=*/linalgOp.getOutputs(),
1042           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1043           linalgOp.iterator_types(),
1044           /*doc=*/nullptr,
1045           /*library_call=*/nullptr,
1046           /*sparse=*/nullptr);
1047 
1048       // Map the block argument corresponding to the replaced argument with the
1049       // scalar constant.
1050       Region &linalgOpRegion = linalgOp->getRegion(0);
1051       Block &entryBlock = *linalgOpRegion.begin();
1052       unsigned argIndex = entryBlock.getNumArguments() -
1053                           linalgOp.getNumShapedOperands() + operand.index();
1054       BlockAndValueMapping mapping;
1055       mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
1056       Region &fusedRegion = fusedOp->getRegion(0);
1057       rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
1058                                  fusedRegion.begin(), mapping);
1059       rewriter.replaceOp(linalgOp, fusedOp->getResults());
1060       if (constantOp.use_empty())
1061         rewriter.eraseOp(constantOp);
1062       return success();
1063     }
1064     return failure();
1065   }
1066 };
1067 } // namespace
1068 
1069 Optional<SmallVector<Value, 1>>
1070 mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
1071                             OpOperand &consumerOpOperand) {
1072   Operation *producer = consumerOpOperand.get().getDefiningOp();
1073   if (!producer || producer->getNumResults() != 1)
1074     return llvm::None;
1075 
1076   // Fuse when consumer is GenericOp or IndexedGenericOp.
1077   if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) ||
1078       !isa<GenericOp, IndexedGenericOp>(producer))
1079     return llvm::None;
1080 
1081   return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
1082                            rewriter);
1083 }
1084 
1085 namespace {
1086 /// Patterns to fuse a generic op, with the producer of its operands.
1087 template <typename LinalgOpTy>
1088 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
1089   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1090 
1091   LogicalResult matchAndRewrite(LinalgOpTy op,
1092                                 PatternRewriter &rewriter) const override {
1093     // Find the first operand that is defined by another generic op on tensors.
1094     for (OpOperand &opOperand : op.getShapedOpOperands()) {
1095       Operation *producer = opOperand.get().getDefiningOp();
1096       if (!producer)
1097         continue;
1098       Optional<SmallVector<Value, 1>> fusedOpResults =
1099           fuseTensorOps(rewriter, opOperand);
1100       if (fusedOpResults) {
1101         rewriter.replaceOp(op, *fusedOpResults);
1102         if (producer->use_empty())
1103           rewriter.eraseOp(producer);
1104         return success();
1105       }
1106     }
1107     return failure();
1108   }
1109 };
1110 
1111 /// Pass that fuses generic ops on tensors. Used only for testing.
1112 struct FusionOfTensorOpsPass
1113     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
1114   void runOnOperation() override {
1115     OwningRewritePatternList patterns;
1116     Operation *op = getOperation();
1117     populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
1118     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1119   }
1120 };
1121 
1122 /// Pass to test folding of reshape op with generic/indexed_generic ops by
1123 /// linearization.
1124 struct FoldReshapeOpsByLinearizationPass
1125     : public LinalgFoldReshapeOpsByLinearizationBase<
1126           FoldReshapeOpsByLinearizationPass> {
1127   void runOnOperation() override {
1128     OwningRewritePatternList patterns;
1129     Operation *op = getOperation();
1130     populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
1131     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1132   }
1133 };
1134 
1135 } // namespace
1136 
1137 void mlir::populateFoldReshapeOpsByLinearizationPatterns(
1138     MLIRContext *context, OwningRewritePatternList &patterns) {
1139   patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
1140                   FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
1141                   FoldConsumerReshapeOpByLinearization<false>>(context);
1142 }
1143 
1144 void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
1145     MLIRContext *context, OwningRewritePatternList &patterns) {
1146   patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
1147                   FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
1148                   FoldConsumerReshapeOpByLinearization<true>>(context);
1149 }
1150 
1151 void mlir::populateFoldReshapeOpsByExpansionPatterns(
1152     MLIRContext *context, OwningRewritePatternList &patterns) {
1153   patterns.insert<FoldReshapeWithGenericOpByExpansion,
1154                   FoldWithProducerReshapeOpByExpansion<GenericOp>,
1155                   FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
1156       context);
1157 }
1158 
1159 void mlir::populateLinalgTensorOpsFusionPatterns(
1160     MLIRContext *context, OwningRewritePatternList &patterns) {
1161   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
1162                   FoldSplatConstants<GenericOp>,
1163                   FoldSplatConstants<IndexedGenericOp>>(context);
1164   populateFoldReshapeOpsByExpansionPatterns(context, patterns);
1165   GenericOp::getCanonicalizationPatterns(patterns, context);
1166   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
1167   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
1168 }
1169 
1170 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
1171   return std::make_unique<FusionOfTensorOpsPass>();
1172 }
1173 
1174 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1175   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1176 }
1177