1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 //===---------------------------------------------------------------------===//
31 // Methods and patterns that fuse elementwise `linalg.generic` operations.
32 //===---------------------------------------------------------------------===//
33 
34 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
35 /// the `producer` to use in the fused operation given the indexing map of the
36 /// result of the producer in the consumer.
37 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
38     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
39     AffineMap fusedConsumerArgIndexMap) {
40   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
41   // from consumer loop -> consumer arg tensor index/producer result tensor
42   // index. The fused loop is same as the consumer loop. For each producer arg
43   // the indexing map to be computed is a map from consumer loop -> producer
44   // arg tensor index.
45   // producerResultIndexMap is a map from producer loop -> tensor index.
46   // Compute the inverse to get map from tensor index -> producer loop.
47   // The inverse is a map from producer result tensor index -> producer loop.
48   AffineMap invProducerResultIndexMap =
49       inversePermutation(producerResultIndexMap);
50   assert(invProducerResultIndexMap &&
51          "expected producer result indexig map to be invertible");
52 
53   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
54   // argMap is a map from producer loop -> producer arg tensor index.
55   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
56 
57   // Compose argMap with invProducerResultIndexMap to get a map from
58   // producer result tensor index -> producer arg tensor index.
59   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
60 
61   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
62   // consumer loop/ fused loop -> producer arg tensor index.
63   return t1.compose(fusedConsumerArgIndexMap);
64 }
65 
66 /// Conditions for elementwise fusion of generic operations.
67 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
68                                      OpOperand *consumerOpOperand) {
69   // Producer and consumer must have tensor semantics.
70   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
71     return false;
72 
73   // Verify that
74   // - the producer has all "parallel" iterator type.
75   if (producer.getNumParallelLoops() != producer.getNumLoops())
76     return false;
77 
78   // Only allow fusing the producer of an input operand for now.
79   // TODO: allow fusing the producer of an output operand.
80   if (!consumer.isInputTensor(consumerOpOperand))
81     return false;
82 
83   // Get the consumer index map. The number of results of the consumer index
84   // map must match the number of loops of the producer.
85   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
86   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
87     return false;
88 
89   // Currently support only operations with single result.
90   if (producer.getNumOutputs() != 1)
91     return false;
92 
93   // Finally the index_map for the result must be invertible. For now just
94   // verify it is a permutation.
95   AffineMap producerResultIndexMap =
96       producer.getTiedIndexingMap(producer.getOutputOperand(0));
97   if (!producerResultIndexMap.isPermutation())
98     return false;
99 
100   // Ensure that the fusion does not remove size information required to
101   // get the loop bounds. For non-reduction generics, this is trivially the
102   // case due to the output operand. For reductions, we need to check that after
103   // the fusion, each loop dimension has at least one input that defines it.
104   if ((consumer.getNumReductionLoops())) {
105     BitVector coveredDims(consumer.getNumLoops(), false);
106 
107     auto addToCoveredDims = [&](AffineMap map) {
108       for (auto result : map.getResults())
109         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
110           coveredDims[dimExpr.getPosition()] = true;
111     };
112 
113     for (auto pair :
114          llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
115       Value operand = std::get<0>(pair);
116       if (operand == consumerOpOperand->get())
117         continue;
118       AffineMap operandMap = std::get<1>(pair);
119       addToCoveredDims(operandMap);
120     }
121 
122     for (OpOperand *operand : producer.getInputOperands()) {
123       AffineMap newIndexingMap =
124           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
125               operand, producerResultIndexMap, consumerIndexMap);
126       addToCoveredDims(newIndexingMap);
127     }
128     if (!coveredDims.all())
129       return false;
130   }
131 
132   return true;
133 }
134 
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
///
/// The fused block arguments are added in the order: consumer inputs before
/// the fused operand (step 3), producer inputs (step 4), optionally the
/// producer init-tensor output (step 4.b), remaining consumer inputs (step 5),
/// consumer outputs (step 6). The step numbering matches the operand list
/// built in `fuseElementwiseOpsImpl`.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  // NOTE(review): `nloops` is not referenced in this body; it appears to be
  // kept only for the existing call signature — confirm before removing.
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  // Maps block arguments and results of the producer/consumer blocks to their
  // clones (or replacements) in the fused block.
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    // Each producer linalg.index is replaced by an affine.apply that maps the
    // fused-loop (consumer-loop) indices to the corresponding producer index.
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation. Index ops were already remapped in step 2, so the
  // mapper substitutes their replacements during cloning.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks, if replacement is not already in the mapper then it must be
  // produced outside.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op. This includes the
  // consumer's terminator, which becomes the terminator of the fused block.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
247 
248 static Optional<SmallVector<Value>>
249 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
250                        const ControlElementwiseOpsFusionFn &controlFn,
251                        PatternRewriter &rewriter) {
252   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
253   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
254       !controlFn(producer->getResult(0), *consumerOpOperand))
255     return llvm::None;
256 
257   // TODO: allow fusing the producer of an output operand.
258   assert(consumer.isInputTensor(consumerOpOperand) &&
259          "expected producer of input operand");
260 
261   // Compute the fused operands list and indexing maps.
262   SmallVector<Value> fusedOperands;
263   SmallVector<AffineMap> fusedIndexMaps;
264   fusedOperands.reserve(producer->getNumOperands() +
265                         consumer->getNumOperands());
266   fusedIndexMaps.reserve(producer->getNumOperands() +
267                          consumer->getNumOperands());
268   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
269   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
270   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
271   SmallVector<OpOperand *>::iterator it =
272       llvm::find(consumerInputs, consumerOpOperand);
273   assert(it != consumerInputs.end() && "expected to find the consumer operand");
274   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
275     fusedOperands.push_back(opOperand->get());
276     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
277   }
278   // 4. Splice in producer's input operands/maps.
279   assert(producer->getNumResults() == 1 && "expected single result producer");
280   AffineMap producerResultIndexMap =
281       producer.getTiedIndexingMap(producer.getOutputOperand(0));
282   for (OpOperand *opOperand : producer.getInputOperands()) {
283     fusedOperands.push_back(opOperand->get());
284     // Compute indexing maps for the producer args in the fused operation.
285     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
286         opOperand, producerResultIndexMap,
287         consumer.getTiedIndexingMap(consumerOpOperand));
288     fusedIndexMaps.push_back(map);
289   }
290   // 4.b. Producer output operand/map that is fused needs to be passed if it is
291   // an "initTensor" (i.e. its value is actually read).
292   assert(producer->getNumResults() == 1 && "expected single result producer");
293   if (producer.isInitTensor(producer.getOutputOperand(0))) {
294     fusedOperands.push_back(producer.getOutputOperand(0)->get());
295     // Compute indexing maps for the producer args in the fused operation.
296     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
297         producer.getOutputOperand(0), producerResultIndexMap,
298         consumer.getTiedIndexingMap(consumerOpOperand));
299     fusedIndexMaps.push_back(map);
300   }
301   // 5. Remaining consumer's input operands/maps (drop past index
302   // `consumerIdx`).
303   for (OpOperand *opOperand :
304        llvm::make_range(std::next(it), consumerInputs.end())) {
305     fusedOperands.push_back(opOperand->get());
306     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
307   }
308   // 6. All of consumer's output operands (skip operands: added by the builder).
309   for (OpOperand *opOperand : consumer.getOutputOperands())
310     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
311   // 7. All of producer's output operands/maps except the one fused.
312   // TODO: allow fusion of multi-result producers.
313   assert(producer->getNumResults() == 1 && "expected single result producer");
314 
315   // Generate the fused op.
316   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
317   auto fusedOp = rewriter.create<GenericOp>(
318       consumer.getLoc(), consumer->getResultTypes(),
319       /*inputs=*/fusedOperands,
320       // TODO: handle outputs.
321       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
322       consumer.iterator_types(),
323       /*doc=*/nullptr,
324       /*library_call=*/nullptr);
325   if (!fusedOp.getShapesToLoopsMap()) {
326     // Fused op has invalid indexing maps. Typically this means something is off
327     // in the input, but going ahead here would result in verification errors.
328     // So cleanup and abort.
329     rewriter.eraseOp(fusedOp);
330     return llvm::None;
331   }
332 
333   // Construct an AffineMap from consumer loops to producer loops.
334   // consumer loop -> tensor index
335   AffineMap consumerResultIndexMap =
336       consumer.getTiedIndexingMap(consumerOpOperand);
337   // tensor index -> producer loop
338   AffineMap invProducerResultIndexMap =
339       inversePermutation(producerResultIndexMap);
340   assert(invProducerResultIndexMap &&
341          "expected producer result indexig map to be invertible");
342   // consumer loop -> producer loop
343   AffineMap consumerToProducerLoopsMap =
344       invProducerResultIndexMap.compose(consumerResultIndexMap);
345 
346   generateFusedElementwiseOpRegion(rewriter, fusedOp,
347                                    consumerToProducerLoopsMap,
348                                    consumerOpOperand, consumer.getNumLoops());
349   return SmallVector<Value>(fusedOp->getResults());
350 }
351 
352 static Optional<SmallVector<Value>>
353 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
354                    GenericOp producer,
355                    const ControlElementwiseOpsFusionFn &controlFn) {
356   if (producer->getNumResults() != 1)
357     return llvm::None;
358 
359   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360                                 rewriter);
361 }
362 
363 namespace {
364 /// Patterns to fuse a generic op, with the producer of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
367   FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
368                      PatternBenefit benefit = 1)
369       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
370 
371   LogicalResult matchAndRewrite(GenericOp genericOp,
372                                 PatternRewriter &rewriter) const override {
373     // Find the first operand that is defined by another generic op on tensors.
374     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
375       auto producer =
376           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
377       if (!producer || !producer.hasTensorSemantics())
378         continue;
379       Optional<SmallVector<Value>> fusedOpResults =
380           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
381       if (fusedOpResults) {
382         rewriter.replaceOp(genericOp, *fusedOpResults);
383         return success();
384       }
385     }
386     return failure();
387   }
388 
389 private:
390   ControlElementwiseOpsFusionFn controlFn;
391 };
392 } // namespace
393 
394 //===---------------------------------------------------------------------===//
395 // Methods and patterns that fuse reshape ops with elementwise operations by
396 // linearization of indexing maps.
397 //===---------------------------------------------------------------------===//
398 
399 // TODO(ravishankarm): These patterns need to be deprecated. The indexing maps
400 // these produce in the general case are detrimental to transformations.
401 // They are useful now only in the limited case of unit-dimension folding.
402 // Remove these in favor of more general folding by dimension contraction.
403 
404 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
405 /// provided, given the shape of the source tensor that corresponds to the
406 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
407 /// are "row-major" ordered logically.
408 ///
409 /// For example:
410 ///
411 /// %0 = op ... : tensor<?x?x4x5xf32>
412 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
413 ///
414 /// and reshape:
/// %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
416 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
417 ///
418 /// would be rewritten into:
419 /// %0 = op ... : tensor<?x?x4x5xf32>
420 /// with output index_map
421 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
422 template <typename TensorReshapeOp>
423 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
424                                         TensorReshapeOp reshapeOp) {
425   constexpr bool isExpanding =
426       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
427   ArrayRef<int64_t> sourceShape =
428       (isExpanding ? reshapeOp.getResultType().getShape()
429                    : reshapeOp.getSrcType().getShape());
430   SmallVector<AffineExpr> resultExprs;
431   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
432   MLIRContext *context = sourceMap.getContext();
433 
434   // Compute the result exprs based on the reassociation maps.
435   for (auto &indices : reshapeOp.getReassociationIndices()) {
436     // Assume that they are in-order and contiguous (already checked in
437     // verifier).
438     assert(!indices.empty());
439     SmallVector<int64_t> sizes;
440     SmallVector<AffineExpr> dimExprs;
441     for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
442                              sourceExprs.slice(indices[0], indices.size()))) {
443       if (std::get<0>(en) == 1)
444         continue;
445       sizes.push_back(std::get<0>(en));
446       dimExprs.push_back(std::get<1>(en));
447     }
448     AffineExpr linearizedExpr =
449         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
450     resultExprs.push_back(linearizedExpr);
451   }
452   // The new affine map cannot drop unused dimension but some new symbols may
453   // have been added. Create a map with at least as many dimensions/symbols as
454   // the original affine map.
455   int64_t maxDim = -1;
456   int64_t maxSym = -1;
457   getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
458   unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
459   unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
460   return AffineMap::get(numDims, numSyms, resultExprs, context);
461 }
462 
463 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
464 // producer). Fusing when operand has higher rank will require use of mods and
465 // divs in the indexing maps of the fused op which would make it non-invertible.
466 static bool isTensorReshapeOpFoldableByLinearization(
467     tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
468   if (!asProducer)
469     return false;
470   return useIndexMap.isPermutation();
471 }
472 
473 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
474 // consumer).
475 static bool
476 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
477                                          AffineMap useIndexMap,
478                                          bool asProducer) {
479   if (asProducer)
480     return false;
481   return useIndexMap.isPermutation();
482 }
483 
484 /// Check if the reshape operation is only expansion into/collapsing of
485 /// unit-dimension.
486 template <typename TensorReshapeOp>
487 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
488   constexpr bool isExpanding =
489       std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
490   ArrayRef<int64_t> expandedShape =
491       (isExpanding ? reshapeOp.getResultType().getShape()
492                    : reshapeOp.getSrcType().getShape());
493   for (auto &indices : reshapeOp.getReassociationIndices()) {
494     unsigned numUnitDims = 0;
495     for (int64_t position : indices)
496       if (expandedShape[position] == 1)
497         numUnitDims++;
498     if (numUnitDims != indices.size() - 1)
499       return false;
500   }
501   return true;
502 }
503 
504 namespace {
505 /// Pattern to fold tensor_expand_shape op with its consumer by using the source
506 /// of the reshape op as the operand in the consumer (instead of the result of
507 /// the tensor_collapse_shape). The corresponding index map in the consumer
508 /// needs to be modified to linearize the folded dimension.
509 ///
510 /// For example,
511 ///
512 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
513 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
514 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
515 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
516 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
517 ///        -> tensor<?x?x4x?xf32>
518 ///
519 /// can be folded into
520 ///
521 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
522 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
523 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
524 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
525 ///        -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    // Fold the first input operand that is fed by a foldable reshape; the
    // generic op is updated in place (operands + indexing maps).
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      // When `foldUnitDimReshapesOnly` is set, restrict to reshapes that only
      // add/remove unit dimensions.
      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list: replace the reshape result with the
      // reshape source; all other operands are unchanged.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumers that arent fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap =
          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      // Semi-affine (mod/div) expressions would make the fused op illegal.
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // In-place update: swap in the new operands and indexing maps.
      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};
586 
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Match only reshapes of a single-output tensor-semantics generic op,
    // with a foldable (permutation) output map. With
    // `foldUnitDimReshapesOnly`, restrict to unit-dimension reshapes.
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer =*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are same as
    // those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
    // Semi-affine (mod/div) expressions would make the fused op illegal.
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    // The output map is the last map (single-output producer checked above).
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    // Reshape the producer's init operand so the new generic op's output has
    // the reshaped type.
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    // The region of the producer is reused verbatim.
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};
648 } // namespace
649 
650 //===---------------------------------------------------------------------===//
651 // Methods and patterns that fuse reshape ops with elementwise operations by
652 // expanding the dimensionality of the elementwise operations.
653 //===---------------------------------------------------------------------===//
654 
655 /// Conditions for folding a generic operation with a reshape op by expanding
656 /// the iteration space dimensionality for tensor operations. These are
657 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
658 /// following fusion pattern.
659 ///
660 ///  Consider
661 ///
662 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
663 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
664 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
665 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
666 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
667 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
668 ///
669 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
670 ///  is increased to match the result (operand) of the tensor_expand_shape.
671 ///  The indexing_map of the fused tensor in the `genericOp` and the
672 ///  reassociation map helps compute the indexing maps of the modified op.
673 ///  For the above example, based on the reassociation map it
674 ///  can be concluded that
675 ///
676 ///  - The loop used to access the first dimension of the fused tensor is split
677 ///    into two.
678 ///  - The loop used to access the second dimension of the fused tensor is kept
679 ///    as is.
680 ///  - The loop used to access the third dimension of the fused tensor is split
681 ///    into three.
682 ///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
685 ///
686 ///   d0 -> e0, e1
687 ///   d1 -> e2, e3, e4
688 ///   d2 -> e5
689 ///
690 ///  substituting this, the generic op can be rewritten as
691 ///
692 ///  %d = linalg.generic ins(%0, %1 : )
693 ///        indexing_maps =
694 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
695 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
696 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
697 ///
698 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
699 ///  to make it consistent
700 ///
701 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
702 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
703 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
704 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
705 ///
706 ///  The added reshapes are again expanding patterns, so they will get fused
707 ///  with its producers if possible.
708 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
709                                                OpOperand *fusableOpOperand) {
710   // Is fusable only if:
711   // - All the indexing maps for operands and results are projected
712   //   permutations.
713   // - The fused tensor is not a scalar.
714   // - All the loops are parallel loops.
715   return genericOp.hasTensorSemantics() &&
716          llvm::all_of(genericOp.indexing_maps().getValue(),
717                       [](Attribute attr) {
718                         return attr.cast<AffineMapAttr>()
719                             .getValue()
720                             .isProjectedPermutation();
721                       }) &&
722          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
723          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
724            return attr.cast<StringAttr>().getValue() ==
725                   getParallelIteratorTypeName();
726          });
727 }
728 
namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassocationMaps` of the reshape op and the shape of
  // the expanded op. Fails if the reassociation is empty or if the static
  // loop ranges of `linalgOp` cannot be computed.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  /// Returns the number of loop dimensions of the original operation.
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  /// Returns the number of loop dimensions of the expanded operation.
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  /// Returns the expanded-op dimensions that original dimension `i` expands
  /// into.
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  /// Returns the extents of the expanded dimensions corresponding to original
  /// dimension `i`.
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  /// Returns the loop extents of the original operation.
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  /// Total number of loop dimensions in the expanded operation.
  unsigned expandedOpNumDims;
};
} // namespace
765 
766 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
767                                      OpOperand *fusableOpOperand,
768                                      ArrayRef<AffineMap> reassociationMaps,
769                                      ArrayRef<int64_t> expandedShape,
770                                      ArrayRef<int64_t> collapsedShape,
771                                      PatternRewriter &rewriter) {
772   if (reassociationMaps.empty())
773     return failure();
774   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
775 
776   Optional<SmallVector<int64_t, 4>> originalLoopRange =
777       linalgOp.getStaticLoopRanges();
778   if (!originalLoopRange)
779     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
780   originalLoopExtent.assign(originalLoopRange->begin(),
781                             originalLoopRange->end());
782 
783   reassociation.clear();
784   expandedShapeMap.clear();
785   // Compute the number of dimension in the expanded op that correspond to each
786   // dimension of the original op.
787   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
788   expandedShapeMap.resize(fusedIndexMap.getNumDims());
789   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
790     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
791     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
792     numExpandedDims[pos] = foldedDims.getNumResults();
793     ArrayRef<int64_t> shape =
794         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
795     expandedShapeMap[pos].assign(shape.begin(), shape.end());
796   }
797   // The remaining dimensions remain the same.
798   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
799     if (expandedShapeMap[i].empty())
800       expandedShapeMap[i] = {originalLoopExtent[i]};
801 
802   // Compute reassociation map from the original op to the expanded op.
803   unsigned sum = 0;
804   reassociation.reserve(fusedIndexMap.getNumDims());
805   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
806     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
807     reassociation.emplace_back(seq.begin(), seq.end());
808     sum += numFoldedDim.value();
809   }
810   expandedOpNumDims = sum;
811   return success();
812 }
813 
/// Expanding the body of a linalg operation requires adaptations of the accessed
815 /// loop indices. Specifically, access of indices in the original operation need
816 /// to be replaced with linearizations of indices in the expanded op. That
817 /// requires the shape of the expanded dimensions to be static (at least all but
818 /// the most significant). For now check that these are all statically sized.
819 /// Note that this could be extended to handle dynamic case, but the
820 /// implementation below uses `affine.apply` which seems to have issues when the
821 /// shapes are not static.
822 static LogicalResult isGenericOpExpandable(GenericOp genericOp,
823                                            const ExpansionInfo &expansionInfo,
824                                            PatternRewriter &rewriter) {
825   if (!genericOp.hasIndexSemantics())
826     return success();
827   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
828     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
829     if (expandedShape.size() == 1)
830       continue;
831     for (int64_t shape : expandedShape.drop_front()) {
832       if (ShapedType::isDynamic(shape)) {
833         return rewriter.notifyMatchFailure(
834             genericOp, "cannot expand due to index semantics and dynamic dims");
835       }
836     }
837   }
838   return success();
839 }
840 
841 /// Return the indexing map to use in the expanded op for a given the
842 /// `indexingMap` of the original operation.
843 static AffineMap
844 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
845                            const ExpansionInfo &expansionInfo) {
846   SmallVector<AffineExpr> newExprs;
847   for (AffineExpr expr : indexingMap.getResults()) {
848     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
849     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
850         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
851           return builder.getAffineDimExpr(static_cast<unsigned>(v));
852         }));
853     newExprs.append(expandedExprs.begin(), expandedExprs.end());
854   }
855   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
856                         indexingMap.getNumSymbols(), newExprs,
857                         builder.getContext());
858 }
859 
860 /// Return the type of the operand/result to use in the expanded op given the
861 /// type in the original op.
862 static RankedTensorType getExpandedType(RankedTensorType originalType,
863                                         AffineMap indexingMap,
864                                         const ExpansionInfo &expansionInfo) {
865   SmallVector<int64_t> expandedShape;
866   for (AffineExpr expr : indexingMap.getResults()) {
867     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
868     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
869     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
870   }
871   return RankedTensorType::get(expandedShape, originalType.getElementType());
872 }
873 
874 /// Returns the reassociation maps to use in the `tensor.expand_shape`
875 /// operation to convert the operands of the original operation to operands of
876 /// the expanded operation. The same method is used to compute the
877 /// `tensor.collapse_shape` used to collapse the result of the expanded
878 /// op to get the value that can replace all uses of the results of the original
879 /// op.
880 static SmallVector<ReassociationIndices>
881 getReassociationForExpansion(AffineMap indexingMap,
882                              const ExpansionInfo &expansionInfo) {
883   SmallVector<ReassociationIndices> reassociation;
884   unsigned numReshapeDims = 0;
885   for (AffineExpr expr : indexingMap.getResults()) {
886     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
887     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
888     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
889         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
890     reassociation.emplace_back(std::move(indices));
891     numReshapeDims += numExpandedDims;
892   }
893   return reassociation;
894 }
895 
896 /// Update the body of an expanded linalg operation having index semantics. The
897 /// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now it
899 /// is assumed that the shapes of the expanded operation needed for
900 /// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  // `early_inc_range` is needed since each IndexOp is replaced while walking.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    // Extents of all but the outermost expanded dimension; these become the
    // multipliers of the linearization and must be static (checked by
    // `isGenericOpExpandable`).
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    // Fold the inner indices into the outermost one:
    //   newIndex = (...((outer * s1 + i1) * s2 + i2)...) * sk + ik
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
938 
939 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
940 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
941 /// that those conditions have been satisfied.
942 static Optional<SmallVector<Value>>
943 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
944                            OpOperand *fusableOpOperand,
945                            PatternRewriter &rewriter) {
946   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
947          "preconditions for fuse operation failed");
948   // Check if reshape is expanding or collapsing.
949   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
950   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
951   bool isExpanding = (expandingReshapeOp != nullptr);
952   RankedTensorType expandedType = isExpanding
953                                       ? expandingReshapeOp.getResultType()
954                                       : collapsingReshapeOp.getSrcType();
955   RankedTensorType collapsedType = isExpanding
956                                        ? expandingReshapeOp.getSrcType()
957                                        : collapsingReshapeOp.getResultType();
958 
959   ExpansionInfo expansionInfo;
960   if (failed(expansionInfo.compute(
961           genericOp, fusableOpOperand,
962           isExpanding ? expandingReshapeOp.getReassociationMaps()
963                       : collapsingReshapeOp.getReassociationMaps(),
964           expandedType.getShape(), collapsedType.getShape(), rewriter)))
965     return llvm::None;
966 
967   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
968     return llvm::None;
969 
970   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
971       llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
972         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
973       }));
974 
975   SmallVector<Value> expandedOpOperands;
976   expandedOpOperands.reserve(genericOp.getNumInputs());
977   for (OpOperand *opOperand : genericOp.getInputOperands()) {
978     if (opOperand == fusableOpOperand) {
979       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
980                                                : collapsingReshapeOp.src());
981       continue;
982     }
983     if (genericOp.isInputTensor(opOperand)) {
984       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
985       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
986       RankedTensorType expandedOperandType =
987           getExpandedType(opOperandType, indexingMap, expansionInfo);
988       if (expandedOperandType != opOperand->get().getType()) {
989         // Reshape the operand to get the right type.
990         SmallVector<ReassociationIndices> reassociation =
991             getReassociationForExpansion(indexingMap, expansionInfo);
992         if (failed(reshapeLikeShapesAreCompatible(
993                 [&](const Twine &msg) {
994                   return rewriter.notifyMatchFailure(genericOp, msg);
995                 },
996                 opOperandType.getShape(), expandedOperandType.getShape(),
997                 reassociation,
998                 /*isExpandingReshape=*/true)))
999           return llvm::None;
1000         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
1001             genericOp.getLoc(), expandedOperandType, opOperand->get(),
1002             reassociation));
1003         continue;
1004       }
1005     }
1006     expandedOpOperands.push_back(opOperand->get());
1007   }
1008 
1009   Location loc = genericOp.getLoc();
1010   SmallVector<Value> outputs;
1011   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
1012     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1013     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
1014     RankedTensorType expandedOutputType =
1015         getExpandedType(opOperandType, indexingMap, expansionInfo);
1016     if (expandedOutputType != opOperand->get().getType()) {
1017       SmallVector<ReassociationIndices> reassociation =
1018           getReassociationForExpansion(indexingMap, expansionInfo);
1019       if (failed(reshapeLikeShapesAreCompatible(
1020               [&](const Twine &msg) {
1021                 return rewriter.notifyMatchFailure(genericOp, msg);
1022               },
1023               opOperandType.getShape(), expandedOutputType.getShape(),
1024               reassociation,
1025               /*isExpandingReshape=*/true)))
1026         return llvm::None;
1027       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
1028           genericOp.getLoc(), expandedOutputType, opOperand->get(),
1029           reassociation));
1030     }
1031   }
1032 
1033   // The iterator types of the expanded op are all parallel.
1034   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
1035                                        getParallelIteratorTypeName());
1036 
1037   TypeRange resultTypes = ValueRange(outputs).getTypes();
1038   auto fusedOp =
1039       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
1040                                  /*inputs=*/expandedOpOperands, outputs,
1041                                  expandedOpIndexingMaps, iteratorTypes);
1042   Region &fusedRegion = fusedOp->getRegion(0);
1043   Region &originalRegion = genericOp->getRegion(0);
1044   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
1045 
1046   // Update the index accesses after the expansion.
1047   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
1048 
1049   // Reshape the result values to their original shape if this is a collapsing
1050   // reshape folded into its consumer.
1051   SmallVector<Value> resultVals;
1052   for (OpResult opResult : genericOp->getOpResults()) {
1053     int64_t resultNumber = opResult.getResultNumber();
1054     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
1055       SmallVector<ReassociationIndices> reassociation =
1056           getReassociationForExpansion(
1057               genericOp.getTiedIndexingMap(
1058                   genericOp.getOutputOperand(resultNumber)),
1059               expansionInfo);
1060       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
1061           genericOp.getLoc(), opResult.getType(),
1062           fusedOp->getResult(resultNumber), reassociation));
1063     } else {
1064       resultVals.push_back(fusedOp->getResult(resultNumber));
1065     }
1066   }
1067   // Assuming a single result.
1068   return resultVals;
1069 }
1070 
1071 namespace {
1072 
1073 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
1074 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
1075 /// in the consumer is expanded.
1076 class FoldWithProducerReshapeOpByExpansion
1077     : public OpRewritePattern<GenericOp> {
1078 public:
1079   FoldWithProducerReshapeOpByExpansion(
1080       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1081       PatternBenefit benefit = 1)
1082       : OpRewritePattern<GenericOp>(context, benefit),
1083         controlFoldingReshapes(std::move(foldReshapes)) {}
1084 
1085   LogicalResult matchAndRewrite(GenericOp genericOp,
1086                                 PatternRewriter &rewriter) const override {
1087     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1088       tensor::CollapseShapeOp reshapeOp =
1089           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1090       if (!reshapeOp)
1091         continue;
1092       // Fold only if
1093       // - The tensor reshape op is folding.
1094       // - All constraints of fusing with reshape by expansion are met.
1095       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
1096           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
1097         continue;
1098 
1099       Optional<SmallVector<Value>> replacementValues =
1100           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
1101       if (!replacementValues)
1102         return failure();
1103       rewriter.replaceOp(genericOp, replacementValues.getValue());
1104       return success();
1105     }
1106     return failure();
1107   }
1108 
1109 private:
1110   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1111 };
1112 
1113 /// Pattern to fold a tensor_expand_shape op with its producer generic op
1114 /// by expanding the dimensionality of the loop in the producer op.
1115 struct FoldReshapeWithGenericOpByExpansion
1116     : public OpRewritePattern<tensor::ExpandShapeOp> {
1117 
1118   FoldReshapeWithGenericOpByExpansion(
1119       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
1120       PatternBenefit benefit = 1)
1121       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1122         controlFoldingReshapes(std::move(foldReshapes)) {}
1123 
1124   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1125                                 PatternRewriter &rewriter) const override {
1126     // Fold only if all constraints of fusing with reshape by expansion are met.
1127     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
1128     if (!producer || producer.getNumOutputs() != 1 ||
1129         !isFusableWithReshapeByDimExpansion(producer,
1130                                             producer.getOutputOperand(0)) ||
1131         !controlFoldingReshapes(producer->getResult(0),
1132                                 reshapeOp->getOpOperand(0)))
1133       return failure();
1134     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
1135         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
1136     if (!replacementValues)
1137       return failure();
1138     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
1139     return success();
1140   }
1141 
1142 private:
1143   ControlElementwiseOpsFusionFn controlFoldingReshapes;
1144 };
1145 } // namespace
1146 
1147 //===---------------------------------------------------------------------===//
1148 // Methods and patterns to convert tensor.expand_shape -> linalg.generic
1149 // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
1150 //===---------------------------------------------------------------------===//
1151 
1152 static SmallVector<ReassociationIndices>
1153 getReassociationIndices(ArrayRef<AffineMap> maps) {
1154   SmallVector<ReassociationIndices> reassociation;
1155   for (AffineMap map : maps) {
1156     ReassociationIndices indices;
1157     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
1158       unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
1159       indices.push_back(pos);
1160     }
1161     reassociation.push_back(indices);
1162   }
1163   return reassociation;
1164 }
1165 
1166 namespace {
1167 /// Pattern to move rank reducing reshape after an elementwise linalg generic
1168 /// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
1170 /// within the dimensions we need to merge.
1171 ///
1172 /// For example,
1173 ///
1174 ///  %0 = tensor.expand_shape %A [[0, 1], [2]]
1175 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
1176 ///  %2 = linalg.generic {indexing_maps = [
1177 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1178 ///    affine_map<(d0, d1, d2) -> (d2)>,
1179 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
1180 ///    ["parallel", "parallel", "parallel"]} {
1181 ///  } -> tensor<112x112x16xf32>
1182 ///
1183 ///  into
1184 ///
1185 ///  %2 = linalg.generic {indexing_maps = [
1186 ///    affine_map<(d0, d1) -> (d0, d1)>,
1187 ///    affine_map<(d0, d1) -> (d1)>,
1188 ///    affine_map<(d0, d1) -> (d0, d1)>],
1189 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
1190 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
1191 ///  } -> tensor<12544x16xf32>
1192 ///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
1193 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensor.
    if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and record the dimensions that
    // were merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociate maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the associated reverse map
    // (expanded dim -> original group index).
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg and that we
    // don't need to create new reshapes operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (const auto &en : llvm::enumerate(inputOperands)) {
      // Operands left untouched in step 1 must not access any merged group of
      // size > 1, otherwise they would need a new reshape.
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip merged dimensions except for the last of each group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that their types match the uses.
    SmallVector<Value> newResults;
    for (const auto &result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};
1308 } // namespace
1309 
1310 //===---------------------------------------------------------------------===//
1311 // Methods and patterns that fuse constants with linalg.generic operations.
1312 //===---------------------------------------------------------------------===//
1313 
1314 namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
///
/// The matched constant input operand is dropped from the fused op and its
/// value is instead materialized as a scalar `arith.constant` used directly
/// inside the payload region.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only tensor-semantics ops are handled here.
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      // Returns true when `def` is a single-valued constant: either a splat
      // dense-elements constant with an int/float element type, or a scalar
      // integer/float constant. On success, `constantAttr` is set to the
      // single scalar value.
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      // Skip operands that are not op results, that are not single-valued
      // constants, or that the control function rejects for fusion.
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operations with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        // Drop the constant operand; keep every other input as-is.
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant. Input block arguments come first, so the operand
      // number of the dropped input is also its block-argument index.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  // Policy callback deciding whether a given producer/consumer pair may fuse.
  ControlElementwiseOpsFusionFn controlFn;
};
1422 
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  /// A single scalar element; only one of the two fields is expected to hold
  /// a value, depending on whether the element type is int or float.
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  /// One scalar element per input operand; only one of the two vectors is
  /// populated, depending on whether the element type is int or float.
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Folding replaces the op with a constant, which only makes sense for
    // tensor semantics.
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given we are generating constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    // All inputs must be shaped (no scalar operands).
    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, and then do affine apply on them, and then match back the
    // constant again.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    // The payload must not read the output operand's value: folding discards
    // the original output tensor entirely.
    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases but they have lifetime as the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate spaces for compute function inputs. Initial values do not matter
    // here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      // Delinearize into per-loop-dimension indices (row-major order).
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      // Apply each operand's permutation map to get per-operand indices.
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      // Re-linearize the per-operand indices against each operand's shape.
      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  // Policy callback deciding whether a given producer/consumer pair may fold.
  ControlElementwiseOpsFusionFn controlFn;
};
1662 
1663 // Folds linalg.generic ops that are actually transposes on constant values.
1664 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
1665   using FoldConstantBase::FoldConstantBase;
1666 
1667   bool matchIndexingMaps(GenericOp genericOp) const {
1668     // We should have one input and one output.
1669     return genericOp.getIndexingMaps().size() == 2;
1670   }
1671 
1672   RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
1673     // Make sure the region only contains a yield op.
1674     Block &body = genericOp.region().front();
1675     if (!llvm::hasSingleElement(body))
1676       return nullptr;
1677     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1678     if (!yieldOp)
1679       return nullptr;
1680 
1681     // The yield op should return the block argument corresponds to the input.
1682     for (Value yieldVal : yieldOp.values()) {
1683       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
1684       if (!yieldArg || yieldArg.getOwner() != &body)
1685         return nullptr;
1686       if (yieldArg.getArgNumber() != 0)
1687         return nullptr;
1688     }
1689 
1690     // No computation; just return the orginal value.
1691     return [](const APIntOrFloatArray &inputs) {
1692       if (inputs.apFloats.empty())
1693         return APIntOrFloat{inputs.apInts.front(), llvm::None};
1694       return APIntOrFloat{llvm::None, inputs.apFloats.front()};
1695     };
1696   }
1697 
1698   ControlElementwiseOpsFusionFn controlFn;
1699 };
1700 
1701 } // namespace
1702 
1703 //===---------------------------------------------------------------------===//
1704 // Miscellaneous patterns that help fusion.
1705 //===---------------------------------------------------------------------===//
1706 
1707 namespace {
1708 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
1709 /// the value of the `outs` operand is not used within the op.  This is only
1710 /// implemented for `linalg.generic` operations for now, but should hold for all
1711 /// linalg structured ops.
1712 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1713   using OpRewritePattern<GenericOp>::OpRewritePattern;
1714 
1715   LogicalResult matchAndRewrite(GenericOp op,
1716                                 PatternRewriter &rewriter) const override {
1717     rewriter.startRootUpdate(op);
1718     bool modifiedOutput = false;
1719     Location loc = op.getLoc();
1720     for (OpOperand *opOperand : op.getOutputOperands()) {
1721       if (!op.payloadUsesValueFromOperand(opOperand)) {
1722         Value operandVal = opOperand->get();
1723         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
1724         if (!operandType)
1725           continue;
1726 
1727         // If outs is already an `init_tensor` operation, nothing to do.
1728         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
1729         if (definingOp)
1730           continue;
1731         modifiedOutput = true;
1732         SmallVector<Value> dynamicDims;
1733         for (const auto &dim : llvm::enumerate(operandType.getShape())) {
1734           if (dim.value() != ShapedType::kDynamicSize)
1735             continue;
1736           dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
1737               loc, operandVal, dim.index()));
1738         }
1739         Value initTensor = rewriter.create<InitTensorOp>(
1740             loc, dynamicDims, operandType.getShape(),
1741             operandType.getElementType());
1742         op->setOperand(opOperand->getOperandNumber(), initTensor);
1743       }
1744     }
1745     if (!modifiedOutput) {
1746       rewriter.cancelRootUpdate(op);
1747       return failure();
1748     }
1749     rewriter.finalizeRootUpdate(op);
1750     return success();
1751   }
1752 };
1753 } // namespace
1754 
1755 //===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
1757 //===---------------------------------------------------------------------===//
1758 
1759 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
1760     RewritePatternSet &patterns) {
1761   patterns
1762       .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1763            FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
1764            FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
1765            FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
1766           patterns.getContext());
1767 }
1768 
1769 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
1770     RewritePatternSet &patterns) {
1771   patterns
1772       .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1773            FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
1774            FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
1775            FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
1776           patterns.getContext());
1777 }
1778 
1779 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1780     RewritePatternSet &patterns,
1781     const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
1782   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
1783                                                     controlFoldingReshapes);
1784   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
1785                                                      controlFoldingReshapes);
1786 }
1787 
1788 void mlir::linalg::populateElementwiseOpsFusionPatterns(
1789     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
1790   auto *context = patterns.getContext();
1791   patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
1792                FoldConstantTranspose>(context,
1793                                       options.controlElementwiseOpsFusionFn);
1794   patterns.add<RemoveOutsDependency>(context);
1795   populateFoldReshapeOpsByExpansionPatterns(patterns,
1796                                             options.controlFoldingReshapesFn);
1797   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1798   GenericOp::getCanonicalizationPatterns(patterns, context);
1799   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1800   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1801   context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1802       patterns);
1803 }
1804 
1805 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
1806   auto *context = patterns.getContext();
1807   patterns.add<PushExpandingReshape>(context);
1808 }
1809 
1810 //===---------------------------------------------------------------------===//
1811 // Passes
1812 //===---------------------------------------------------------------------===//
1813 
1814 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
1815                                       OpOperand &consumer) {
1816   if (auto producerCollapseOp =
1817           dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
1818     return !isUnitDimExpansionOnly(producerCollapseOp);
1819   }
1820   if (auto consumerExpandOp =
1821           dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
1822     return !isUnitDimExpansionOnly(consumerExpandOp);
1823   }
1824   return true;
1825 }
1826 
1827 namespace {
1828 
1829 /// Pass that fuses generic ops on tensors. Used only for testing.
1830 struct LinalgElementwiseOpFusionPass
1831     : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1832   void runOnOperation() override {
1833     Operation *op = getOperation();
1834     RewritePatternSet patterns(op->getContext());
1835     ControlElementwiseOpsFusionFn allowFoldingFn =
1836         [](const OpResult &producer, const OpOperand &consumer) {
1837           return true;
1838         };
1839     populateElementwiseOpsFusionPatterns(
1840         patterns,
1841         LinalgElementwiseFusionOptions().setControlFoldingReshapes(
1842             allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
1843 
1844     // Use TopDownTraversal for compile time reasons
1845     GreedyRewriteConfig grc;
1846     grc.useTopDownTraversal = true;
1847     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1848                                        grc);
1849   }
1850 };
1851 
1852 /// Pass to test folding of reshape ops with generic ops by linearization.
1853 struct FoldReshapeOpsByLinearizationPass
1854     : public LinalgFoldReshapeOpsByLinearizationBase<
1855           FoldReshapeOpsByLinearizationPass> {
1856   void runOnOperation() override {
1857     Operation *op = getOperation();
1858     RewritePatternSet patterns(op->getContext());
1859     populateFoldReshapeOpsByLinearizationPatterns(patterns);
1860     if (allowFoldingUnitDimReshapes) {
1861       populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
1862     }
1863     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1864   }
1865 };
1866 
1867 } // namespace
1868 
1869 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
1870   return std::make_unique<LinalgElementwiseOpFusionPass>();
1871 }
1872 
1873 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1874   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1875 }
1876