//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the linalg dialect pass that fuses elementwise
// operations on tensors.
10 //
11 //===----------------------------------------------------------------------===//
12 #include <utility>
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 //===---------------------------------------------------------------------===//
32 // Methods and patterns that fuse elementwise `linalg.generic` operations.
33 //===---------------------------------------------------------------------===//
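//
// As an illustrative sketch (not IR from an actual test case), fusing
//
//   %0 = linalg.generic ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>)
//        { ^bb0(%a: f32, %b: f32): ... linalg.yield %p : f32 }
//   %1 = linalg.generic ins(%0, %arg1 : tensor<?xf32>, tensor<?xf32>)
//        outs(%init : tensor<?xf32>) { ... }
//
// produces a single `linalg.generic` whose region contains both payloads, with
// the producer's yielded value forwarded to the consumer's block argument and
// the intermediate tensor %0 removed.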
34 
35 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
36 /// the `producer` to use in the fused operation given the indexing map of the
37 /// result of the producer in the consumer.
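///
/// As an illustrative sketch (assumed maps, not from a test case): if the
/// producer argument map is (d0, d1) -> (d0, d1), the producer result map is
/// (d0, d1) -> (d1, d0), and the consumer accesses the producer result with
/// (c0, c1) -> (c0, c1), then the inverse of the result map is
/// (t0, t1) -> (t1, t0), and the composed map returned for the producer
/// argument in the fused op is (c0, c1) -> (c1, c0).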
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
39     OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
40     AffineMap fusedConsumerArgIndexMap) {
41   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
42   // from consumer loop -> consumer arg tensor index/producer result tensor
43   // index. The fused loop is same as the consumer loop. For each producer arg
44   // the indexing map to be computed is a map from consumer loop -> producer
45   // arg tensor index.
46   // producerResultIndexMap is a map from producer loop -> tensor index.
47   // Compute the inverse to get map from tensor index -> producer loop.
48   // The inverse is a map from producer result tensor index -> producer loop.
49   AffineMap invProducerResultIndexMap =
50       inversePermutation(producerResultIndexMap);
51   assert(invProducerResultIndexMap &&
52          "expected producer result indexing map to be invertible");
53 
54   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
55   // argMap is a map from producer loop -> producer arg tensor index.
56   AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
57 
58   // Compose argMap with invProducerResultIndexMap to get a map from
59   // producer result tensor index -> producer arg tensor index.
60   AffineMap t1 = argMap.compose(invProducerResultIndexMap);
61 
62   // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
63   // consumer loop/ fused loop -> producer arg tensor index.
64   return t1.compose(fusedConsumerArgIndexMap);
65 }
66 
67 /// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
69                                      OpOperand *consumerOpOperand) {
70   // Producer and consumer must have tensor semantics.
71   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
72     return false;
73 
74   // Verify that
  // - the producer has all "parallel" iterator types.
76   if (producer.getNumParallelLoops() != producer.getNumLoops())
77     return false;
78 
79   // Only allow fusing the producer of an input operand for now.
80   // TODO: allow fusing the producer of an output operand.
81   if (!consumer.isInputTensor(consumerOpOperand))
82     return false;
83 
84   // Get the consumer index map. The number of results of the consumer index
85   // map must match the number of loops of the producer.
86   AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
88     return false;
89 
  // Currently only operations with a single result are supported.
91   if (producer.getNumOutputs() != 1)
92     return false;
93 
94   // Finally the index_map for the result must be invertible. For now just
95   // verify it is a permutation.
96   AffineMap producerResultIndexMap =
97       producer.getTiedIndexingMap(producer.getOutputOperand(0));
98   if (!producerResultIndexMap.isPermutation())
99     return false;
100 
101   // Ensure that the fusion does not remove size information required to
102   // get the loop bounds. For non-reduction generics, this is trivially the
103   // case due to the output operand. For reductions, we need to check that after
104   // the fusion, each loop dimension has at least one input that defines it.
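  // As an illustrative sketch (assumed shapes): if the consumer reduces a
  // tensor<?x?xf32> along d1 and its only input is the producer's result,
  // while the producer merely broadcasts a tensor<?xf32> along that dimension,
  // then after fusion no remaining operand would carry the extent of d1, so
  // the fusion is rejected.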
105   if ((consumer.getNumReductionLoops())) {
106     BitVector coveredDims(consumer.getNumLoops(), false);
107 
108     auto addToCoveredDims = [&](AffineMap map) {
109       for (auto result : map.getResults())
110         if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
111           coveredDims[dimExpr.getPosition()] = true;
112     };
113 
114     for (auto pair :
115          llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
116       Value operand = std::get<0>(pair);
117       if (operand == consumerOpOperand->get())
118         continue;
119       AffineMap operandMap = std::get<1>(pair);
120       addToCoveredDims(operandMap);
121     }
122 
123     for (OpOperand *operand : producer.getInputOperands()) {
124       AffineMap newIndexingMap =
125           getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
126               operand, producerResultIndexMap, consumerIndexMap);
127       addToCoveredDims(newIndexingMap);
128     }
129     if (!coveredDims.all())
130       return false;
131   }
132 
133   return true;
134 }
135 
136 /// Generate the region of the fused tensor operation. The region of the fused
137 /// op must be empty.
138 static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
140                                  AffineMap consumerToProducerLoopsMap,
141                                  OpOperand *consumerOpOperand,
142                                  unsigned nloops) {
143   auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
144   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
145   // Build the region of the fused op.
146   Block &producerBlock = producer->getRegion(0).front();
147   Block &consumerBlock = consumer->getRegion(0).front();
148   Block *fusedBlock = new Block();
149   fusedOp.region().push_back(fusedBlock);
150   BlockAndValueMapping mapper;
151   OpBuilder::InsertionGuard guard(rewriter);
152   rewriter.setInsertionPointToStart(fusedBlock);
153 
154   // 2. Add an index operation for every fused loop dimension and use the
155   // `consumerToProducerLoopsMap` to map the producer indices.
156   if (producer.hasIndexSemantics()) {
157     // Add an index operation for every fused loop dimension.
158     unsigned numFusedOpLoops =
159         std::max(producer.getNumLoops(), consumer.getNumLoops());
160     SmallVector<Value> fusedIndices;
161     fusedIndices.reserve(numFusedOpLoops);
162     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163                     std::back_inserter(fusedIndices), [&](uint64_t dim) {
164                       return rewriter.create<IndexOp>(producer.getLoc(), dim);
165                     });
166     for (IndexOp indexOp :
167          llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
169           producer.getLoc(),
170           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
171       mapper.map(indexOp.getResult(), newIndex);
172     }
173   }
174   // TODO: allow fusing the producer of an output operand.
175   assert(consumer.isInputTensor(consumerOpOperand) &&
176          "expected producer of input operand");
177   // 3. Consumer input operands up to consumerIdx (exclusive).
178   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
179            consumerOpOperand->getOperandNumber())) // input assumption.
180     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
181 
182   // Replacing consumerIdx requires getting the cloned, yielded, value from
183   // the (cloned) producer block. This happens in step 9.
184 
185   // 4. Splice in producer's input operands.
186   for (BlockArgument bbArg :
187        producerBlock.getArguments().take_front(producer.getNumInputs()))
188     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
189 
190   // 4.b. Producer output operand/map that is fused needs to be mapped to the
191   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
192   assert(producer->getNumResults() == 1 && "expected single result producer");
193   if (producer.isInitTensor(producer.getOutputOperand(0))) {
194     BlockArgument bbArg = producerBlock.getArguments()
195                               .drop_front(producer.getNumInputs())
196                               // TODO: bbArg index of
197                               .front();
198     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
199   }
200   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
201   for (BlockArgument bbArg :
202        consumerBlock.getArguments()
203            .take_front(consumer.getNumInputs())
204            .drop_front(consumerOpOperand->getOperandNumber() + 1))
205     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
206   // 6. All of consumer's output operands.
207   for (BlockArgument bbArg :
208        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
209     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
210   // 7. All of producer's output operands except the one fused.
211   // TODO: allow fusion of multi-result producers.
212   assert(producer->getNumResults() == 1 && "expected single result producer");
213 
214   // 8. Clone all producer operations except for the yield and index operations
215   // to the fused operation.
216   for (auto &op : producerBlock.without_terminator()) {
217     if (!isa<IndexOp>(op))
218       rewriter.clone(op, mapper);
219   }
220   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
221   // forward the yield operand.
222   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
223   // TODO: allow fusion of multi-result producers.
224   assert(producer->getNumResults() == 1 && "expected single result producer");
225   unsigned producerResultNumber = 0;
226   Value replacement =
227       mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks: if replacement is not already in the mapper then it must be
229   // produced outside.
230   if (replacement == yieldOp.getOperand(producerResultNumber)) {
231     if (auto bb = replacement.dyn_cast<BlockArgument>())
232       assert(bb.getOwner() != &producerBlock &&
233              "yielded block argument must have been mapped");
234     else
235       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
236              "yielded value must have been mapped");
237   }
238   mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
239              replacement);
240   // 10. Clone operations from the consumer to the fused op.
241   for (auto &op : consumerBlock.getOperations())
242     rewriter.clone(op, mapper);
243 
244   // Sanity checks.
245   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
246          "Ill-formed GenericOp region");
247 }
248 
249 static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
251                        const ControlFusionFn &controlFn,
252                        PatternRewriter &rewriter) {
253   auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
254   if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
255       !controlFn(producer->getResult(0), *consumerOpOperand))
256     return llvm::None;
257 
258   // TODO: allow fusing the producer of an output operand.
259   assert(consumer.isInputTensor(consumerOpOperand) &&
260          "expected producer of input operand");
261 
262   // Compute the fused operands list and indexing maps.
263   SmallVector<Value> fusedOperands;
264   SmallVector<AffineMap> fusedIndexMaps;
265   fusedOperands.reserve(producer->getNumOperands() +
266                         consumer->getNumOperands());
267   fusedIndexMaps.reserve(producer->getNumOperands() +
268                          consumer->getNumOperands());
269   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
270   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
271   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
272   SmallVector<OpOperand *>::iterator it =
273       llvm::find(consumerInputs, consumerOpOperand);
274   assert(it != consumerInputs.end() && "expected to find the consumer operand");
275   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276     fusedOperands.push_back(opOperand->get());
277     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
278   }
279   // 4. Splice in producer's input operands/maps.
280   assert(producer->getNumResults() == 1 && "expected single result producer");
281   AffineMap producerResultIndexMap =
282       producer.getTiedIndexingMap(producer.getOutputOperand(0));
283   for (OpOperand *opOperand : producer.getInputOperands()) {
284     fusedOperands.push_back(opOperand->get());
285     // Compute indexing maps for the producer args in the fused operation.
286     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
287         opOperand, producerResultIndexMap,
288         consumer.getTiedIndexingMap(consumerOpOperand));
289     fusedIndexMaps.push_back(map);
290   }
291   // 4.b. Producer output operand/map that is fused needs to be passed if it is
292   // an "initTensor" (i.e. its value is actually read).
293   assert(producer->getNumResults() == 1 && "expected single result producer");
294   if (producer.isInitTensor(producer.getOutputOperand(0))) {
295     fusedOperands.push_back(producer.getOutputOperand(0)->get());
296     // Compute indexing maps for the producer args in the fused operation.
297     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
298         producer.getOutputOperand(0), producerResultIndexMap,
299         consumer.getTiedIndexingMap(consumerOpOperand));
300     fusedIndexMaps.push_back(map);
301   }
302   // 5. Remaining consumer's input operands/maps (drop past index
303   // `consumerIdx`).
304   for (OpOperand *opOperand :
305        llvm::make_range(std::next(it), consumerInputs.end())) {
306     fusedOperands.push_back(opOperand->get());
307     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
308   }
309   // 6. All of consumer's output operands (skip operands: added by the builder).
310   for (OpOperand *opOperand : consumer.getOutputOperands())
311     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
312   // 7. All of producer's output operands/maps except the one fused.
313   // TODO: allow fusion of multi-result producers.
314   assert(producer->getNumResults() == 1 && "expected single result producer");
315 
316   // Generate the fused op.
317   SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
318   auto fusedOp = rewriter.create<GenericOp>(
319       consumer.getLoc(), consumer->getResultTypes(),
320       /*inputs=*/fusedOperands,
321       // TODO: handle outputs.
322       consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
323       consumer.iterator_types(),
324       /*doc=*/nullptr,
325       /*library_call=*/nullptr);
326   if (!fusedOp.getShapesToLoopsMap()) {
327     // Fused op has invalid indexing maps. Typically this means something is off
328     // in the input, but going ahead here would result in verification errors.
329     // So cleanup and abort.
330     rewriter.eraseOp(fusedOp);
331     return llvm::None;
332   }
333 
334   // Construct an AffineMap from consumer loops to producer loops.
335   // consumer loop -> tensor index
336   AffineMap consumerResultIndexMap =
337       consumer.getTiedIndexingMap(consumerOpOperand);
338   // tensor index -> producer loop
339   AffineMap invProducerResultIndexMap =
340       inversePermutation(producerResultIndexMap);
341   assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
343   // consumer loop -> producer loop
344   AffineMap consumerToProducerLoopsMap =
345       invProducerResultIndexMap.compose(consumerResultIndexMap);
346 
347   generateFusedElementwiseOpRegion(rewriter, fusedOp,
348                                    consumerToProducerLoopsMap,
349                                    consumerOpOperand, consumer.getNumLoops());
350   return SmallVector<Value>(fusedOp->getResults());
351 }
352 
353 static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
355                    GenericOp producer, const ControlFusionFn &controlFn) {
356   if (producer->getNumResults() != 1)
357     return llvm::None;
358 
359   return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360                                 rewriter);
361 }
362 
363 namespace {
/// Patterns to fuse a generic op with the producer of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
368                      PatternBenefit benefit = 1)
369       : OpRewritePattern<GenericOp>(context, benefit),
370         controlFn(std::move(fun)) {}
371 
  LogicalResult matchAndRewrite(GenericOp genericOp,
373                                 PatternRewriter &rewriter) const override {
374     // Find the first operand that is defined by another generic op on tensors.
375     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
376       auto producer =
377           dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378       if (!producer || !producer.hasTensorSemantics())
379         continue;
380       Optional<SmallVector<Value>> fusedOpResults =
381           fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
382       if (fusedOpResults) {
383         rewriter.replaceOp(genericOp, *fusedOpResults);
384         return success();
385       }
386     }
387     return failure();
388   }
389 
390 private:
391   ControlFusionFn controlFn;
392 };
393 } // namespace
394 
395 //===---------------------------------------------------------------------===//
396 // Methods and patterns that fuse reshape ops with elementwise operations by
397 // expanding the dimensionality of the elementwise operations.
398 //===---------------------------------------------------------------------===//
399 
400 /// Conditions for folding a generic operation with a reshape op by expanding
401 /// the iteration space dimensionality for tensor operations. These are
402 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
403 /// following fusion pattern.
404 ///
405 ///  Consider
406 ///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
408 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
409 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
410 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
411 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
412 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
413 ///
414 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
415 ///  is increased to match the result (operand) of the tensor.expand_shape.
416 ///  The indexing_map of the fused tensor in the `genericOp` and the
417 ///  reassociation map helps compute the indexing maps of the modified op.
418 ///  For the above example, based on the reassociation map it
419 ///  can be concluded that
420 ///
421 ///  - The loop used to access the first dimension of the fused tensor is split
422 ///    into two.
423 ///  - The loop used to access the second dimension of the fused tensor is kept
424 ///    as is.
425 ///  - The loop used to access the third dimension of the fused tensor is split
426 ///    into three.
427 ///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
430 ///
431 ///   d0 -> e0, e1
432 ///   d1 -> e2, e3, e4
433 ///   d2 -> e5
434 ///
435 ///  substituting this, the generic op can be rewritten as
436 ///
437 ///  %d = linalg.generic ins(%0, %1 : )
438 ///        indexing_maps =
439 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
440 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
441 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
442 ///
///  Since the operands to the linalg generic are now higher dimensional,
///  reshapes can be introduced to make them consistent
445 ///
446 ///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
447 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
448 ///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
450 ///
451 ///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
454                                                OpOperand *fusableOpOperand) {
455   // Is fusable only if:
456   // - All the indexing maps for operands and results are projected
457   //   permutations.
458   // - The fused tensor is not a scalar.
459   // - All the loops are parallel loops.
460   return genericOp.hasTensorSemantics() &&
461          llvm::all_of(genericOp.indexing_maps().getValue(),
462                       [](Attribute attr) {
463                         return attr.cast<AffineMapAttr>()
464                             .getValue()
465                             .isProjectedPermutation();
466                       }) &&
467          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
468          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
469            return attr.cast<StringAttr>().getValue() ==
470                   getParallelIteratorTypeName();
471          });
472 }
473 
474 namespace {
475 /// Information needed to expand a generic operation to fold the reshape with
476 /// it.
477 class ExpansionInfo {
478 public:
479   // Computes the mapping from original dimensions of the op to the dimensions
480   // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
482   // the expanded op.
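  //
  // As an illustrative sketch (assumed shapes and maps): if the fused operand
  // has indexing map (d0, d1) -> (d1, d0) and the reshape expands it from
  // tensor<6x5xf32> into tensor<2x3x5xf32> with reassociation [[0, 1], [2]],
  // then loop d1 expands into two loops of extents [2, 3], loop d0 keeps a
  // single loop of extent [5], and the expanded op has 3 loops.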
483   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
484                         ArrayRef<AffineMap> reassociationMaps,
485                         ArrayRef<int64_t> expandedShape,
486                         ArrayRef<int64_t> collapsedShape,
487                         PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
491     return reassociation[i];
492   }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
494     return expandedShapeMap[i];
495   }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
497 
498 private:
499   /// Reassociation from the dimensions in the original operation to the
500   /// dimension of the expanded operation.
501   SmallVector<ReassociationIndices> reassociation;
502   /// Mapping from extent of loops in the original operation, to the extent of
503   /// loops in the expanded operation.
504   SmallVector<SmallVector<int64_t>> expandedShapeMap;
505   /// Extent of the loop in the original operation.
506   SmallVector<int64_t> originalLoopExtent;
507   unsigned expandedOpNumDims;
508 };
509 } // namespace
510 
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
512                                      OpOperand *fusableOpOperand,
513                                      ArrayRef<AffineMap> reassociationMaps,
514                                      ArrayRef<int64_t> expandedShape,
515                                      ArrayRef<int64_t> collapsedShape,
516                                      PatternRewriter &rewriter) {
517   if (reassociationMaps.empty())
518     return failure();
519   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
520 
521   SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
522   originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
523 
524   reassociation.clear();
525   expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to each
527   // dimension of the original op.
528   SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
529   expandedShapeMap.resize(fusedIndexMap.getNumDims());
530   for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
531     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
532     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
533     numExpandedDims[pos] = foldedDims.getNumResults();
534     ArrayRef<int64_t> shape =
535         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
536     expandedShapeMap[pos].assign(shape.begin(), shape.end());
537   }
538   // The remaining dimensions remain the same.
539   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
540     if (expandedShapeMap[i].empty())
541       expandedShapeMap[i] = {originalLoopExtent[i]};
542 
543   // Compute reassociation map from the original op to the expanded op.
544   unsigned sum = 0;
545   reassociation.reserve(fusedIndexMap.getNumDims());
546   for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
547     auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
548     reassociation.emplace_back(seq.begin(), seq.end());
549     sum += numFoldedDim.value();
550   }
551   expandedOpNumDims = sum;
552   return success();
553 }
554 
/// Expanding the body of a linalg operation requires adaptations of the accessed
/// loop indices. Specifically, accesses of indices in the original operation need
557 /// to be replaced with linearizations of indices in the expanded op. That
558 /// requires the shape of the expanded dimensions to be static (at least all but
559 /// the most significant). For now check that these are all statically sized.
/// Note that this could be extended to handle the dynamic case, but the
561 /// implementation below uses `affine.apply` which seems to have issues when the
562 /// shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
564                                            const ExpansionInfo &expansionInfo,
565                                            PatternRewriter &rewriter) {
566   if (!genericOp.hasIndexSemantics())
567     return success();
568   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
569     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
570     if (expandedShape.size() == 1)
571       continue;
572     for (int64_t shape : expandedShape.drop_front()) {
573       if (ShapedType::isDynamic(shape)) {
574         return rewriter.notifyMatchFailure(
575             genericOp, "cannot expand due to index semantics and dynamic dims");
576       }
577     }
578   }
579   return success();
580 }
581 
/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
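///
/// As an illustrative sketch (assumed expansion): if d0 expands into (e0, e1)
/// and d1 expands into (e2), the original map (d0, d1) -> (d1, d0) becomes
/// (e0, e1, e2) -> (e2, e0, e1) in the expanded op.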
584 static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
586                            const ExpansionInfo &expansionInfo) {
587   SmallVector<AffineExpr> newExprs;
588   for (AffineExpr expr : indexingMap.getResults()) {
589     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
590     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
591         llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
592           return builder.getAffineDimExpr(static_cast<unsigned>(v));
593         }));
594     newExprs.append(expandedExprs.begin(), expandedExprs.end());
595   }
596   return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
597                         indexingMap.getNumSymbols(), newExprs,
598                         builder.getContext());
599 }
600 
601 /// Return the type of the operand/result to use in the expanded op given the
602 /// type in the original op.
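///
/// As an illustrative sketch (assumed shapes): with indexing map
/// (d0, d1) -> (d1, d0), d0 expanded to shape [2, 3] and d1 to shape [4], an
/// operand of type tensor<4x6xf32> becomes tensor<4x2x3xf32> in the expanded
/// op.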
static RankedTensorType getExpandedType(RankedTensorType originalType,
604                                         AffineMap indexingMap,
605                                         const ExpansionInfo &expansionInfo) {
606   SmallVector<int64_t> expandedShape;
607   for (AffineExpr expr : indexingMap.getResults()) {
608     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
609     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
610     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
611   }
612   return RankedTensorType::get(expandedShape, originalType.getElementType());
613 }
614 
615 /// Returns the reassociation maps to use in the `tensor.expand_shape`
616 /// operation to convert the operands of the original operation to operands of
617 /// the expanded operation. The same method is used to compute the
618 /// `tensor.collapse_shape` used to collapse the result of the expanded
619 /// op to get the value that can replace all uses of the results of the original
620 /// op.
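///
/// As an illustrative sketch (continuing the assumed expansion above): for an
/// operand with indexing map (d0, d1) -> (d1, d0), where d1 stays a single
/// dimension and d0 expands into two, the returned reassociation is
/// [[0], [1, 2]].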
621 static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
623                              const ExpansionInfo &expansionInfo) {
624   SmallVector<ReassociationIndices> reassociation;
625   unsigned numReshapeDims = 0;
626   for (AffineExpr expr : indexingMap.getResults()) {
627     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
628     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
629     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
630         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
631     reassociation.emplace_back(std::move(indices));
632     numReshapeDims += numExpandedDims;
633   }
634   return reassociation;
635 }
636 
637 /// Update the body of an expanded linalg operation having index semantics. The
638 /// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now it
640 /// is assumed that the shapes of the expanded operation needed for
641 /// linearization are static.
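///
/// As an illustrative sketch (assumed static sizes): if the original dimension
/// accessed by `linalg.index 0` is expanded into two dimensions with an inner
/// static extent of 4, the original index is recovered as i1 + i0 * 4 via an
/// `affine.apply`, where i0 and i1 are `linalg.index 0` and `linalg.index 1`
/// of the expanded op.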
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
643                                           Location loc, Region &fusedRegion,
644                                           const ExpansionInfo &expansionInfo) {
645   // Replace the original indices by the linearization of the expanded indices.
646   for (IndexOp indexOp :
647        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
648     ArrayRef<int64_t> expandedDims =
649         expansionInfo.getExpandedDims(indexOp.dim());
650     assert(!expandedDims.empty() && "expected valid expansion info");
651 
652     // Skip index operations that are not affected by the expansion.
653     if (expandedDims.size() == 1 &&
654         expandedDims.front() == (int64_t)indexOp.dim())
655       continue;
656 
657     // Linearize the expanded indices of the original index dimension.
658     OpBuilder::InsertionGuard guard(rewriter);
659     rewriter.setInsertionPointAfter(indexOp);
660     ArrayRef<int64_t> expandedDimsShape =
661         expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
662     SmallVector<Value> expandedIndices;
663     expandedIndices.reserve(expandedDims.size() - 1);
664     llvm::transform(
665         expandedDims.drop_front(), std::back_inserter(expandedIndices),
666         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
667     Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
668     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
669       assert(!ShapedType::isDynamic(std::get<0>(it)));
670       AffineExpr idx, acc;
671       bindDims(rewriter.getContext(), idx, acc);
672       newIndex = rewriter.create<AffineApplyOp>(
673           indexOp.getLoc(), idx + acc * std::get<0>(it),
674           ValueRange{std::get<1>(it), newIndex});
675     }
676     rewriter.replaceOp(indexOp, newIndex);
677   }
678 }
679 
680 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`. Assumes
682 /// that those conditions have been satisfied.
683 static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
685                            OpOperand *fusableOpOperand,
686                            PatternRewriter &rewriter) {
687   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
688          "preconditions for fuse operation failed");
689   // Check if reshape is expanding or collapsing.
690   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
691   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
692   bool isExpanding = (expandingReshapeOp != nullptr);
693   RankedTensorType expandedType = isExpanding
694                                       ? expandingReshapeOp.getResultType()
695                                       : collapsingReshapeOp.getSrcType();
696   RankedTensorType collapsedType = isExpanding
697                                        ? expandingReshapeOp.getSrcType()
698                                        : collapsingReshapeOp.getResultType();
699 
700   ExpansionInfo expansionInfo;
701   if (failed(expansionInfo.compute(
702           genericOp, fusableOpOperand,
703           isExpanding ? expandingReshapeOp.getReassociationMaps()
704                       : collapsingReshapeOp.getReassociationMaps(),
705           expandedType.getShape(), collapsedType.getShape(), rewriter)))
706     return llvm::None;
707 
708   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
709     return llvm::None;
710 
711   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
712       llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
713         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
714       }));
715 
716   SmallVector<Value> expandedOpOperands;
717   expandedOpOperands.reserve(genericOp.getNumInputs());
718   for (OpOperand *opOperand : genericOp.getInputOperands()) {
719     if (opOperand == fusableOpOperand) {
720       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
721                                                : collapsingReshapeOp.getSrc());
722       continue;
723     }
724     if (genericOp.isInputTensor(opOperand)) {
725       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
726       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
727       RankedTensorType expandedOperandType =
728           getExpandedType(opOperandType, indexingMap, expansionInfo);
729       if (expandedOperandType != opOperand->get().getType()) {
730         // Reshape the operand to get the right type.
731         SmallVector<ReassociationIndices> reassociation =
732             getReassociationForExpansion(indexingMap, expansionInfo);
733         if (failed(reshapeLikeShapesAreCompatible(
734                 [&](const Twine &msg) {
735                   return rewriter.notifyMatchFailure(genericOp, msg);
736                 },
737                 opOperandType.getShape(), expandedOperandType.getShape(),
738                 reassociation,
739                 /*isExpandingReshape=*/true)))
740           return llvm::None;
741         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
742             genericOp.getLoc(), expandedOperandType, opOperand->get(),
743             reassociation));
744         continue;
745       }
746     }
747     expandedOpOperands.push_back(opOperand->get());
748   }
749 
750   Location loc = genericOp.getLoc();
751   SmallVector<Value> outputs;
752   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
753     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
754     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
755     RankedTensorType expandedOutputType =
756         getExpandedType(opOperandType, indexingMap, expansionInfo);
757     if (expandedOutputType != opOperand->get().getType()) {
758       SmallVector<ReassociationIndices> reassociation =
759           getReassociationForExpansion(indexingMap, expansionInfo);
760       if (failed(reshapeLikeShapesAreCompatible(
761               [&](const Twine &msg) {
762                 return rewriter.notifyMatchFailure(genericOp, msg);
763               },
764               opOperandType.getShape(), expandedOutputType.getShape(),
765               reassociation,
766               /*isExpandingReshape=*/true)))
767         return llvm::None;
768       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
769           genericOp.getLoc(), expandedOutputType, opOperand->get(),
770           reassociation));
771     }
772   }
773 
774   // The iterator types of the expanded op are all parallel.
775   SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
776                                        getParallelIteratorTypeName());
777 
778   TypeRange resultTypes = ValueRange(outputs).getTypes();
779   auto fusedOp =
780       rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
781                                  /*inputs=*/expandedOpOperands, outputs,
782                                  expandedOpIndexingMaps, iteratorTypes);
783   Region &fusedRegion = fusedOp->getRegion(0);
784   Region &originalRegion = genericOp->getRegion(0);
785   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
786 
787   // Update the index accesses after the expansion.
788   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
789 
790   // Reshape the result values to their original shape if this is a collapsing
791   // reshape folded into its consumer.
792   SmallVector<Value> resultVals;
793   for (OpResult opResult : genericOp->getOpResults()) {
794     int64_t resultNumber = opResult.getResultNumber();
795     if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
796       SmallVector<ReassociationIndices> reassociation =
797           getReassociationForExpansion(
798               genericOp.getTiedIndexingMap(
799                   genericOp.getOutputOperand(resultNumber)),
800               expansionInfo);
801       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
802           genericOp.getLoc(), opResult.getType(),
803           fusedOp->getResult(resultNumber), reassociation));
804     } else {
805       resultVals.push_back(fusedOp->getResult(resultNumber));
806     }
807   }
808   // Assuming a single result.
809   return resultVals;
810 }
811 
812 namespace {
813 
814 /// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
815 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
816 /// in the consumer is expanded.
817 class FoldWithProducerReshapeOpByExpansion
818     : public OpRewritePattern<GenericOp> {
819 public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
821                                        ControlFusionFn foldReshapes,
822                                        PatternBenefit benefit = 1)
823       : OpRewritePattern<GenericOp>(context, benefit),
824         controlFoldingReshapes(std::move(foldReshapes)) {}
825 
  LogicalResult matchAndRewrite(GenericOp genericOp,
827                                 PatternRewriter &rewriter) const override {
828     for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
829       tensor::CollapseShapeOp reshapeOp =
830           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
831       if (!reshapeOp)
832         continue;
833       // Fold only if
834       // - The tensor reshape op is folding.
835       // - All constraints of fusing with reshape by expansion are met.
836       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
837           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
838         continue;
839 
840       Optional<SmallVector<Value>> replacementValues =
841           fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
842       if (!replacementValues)
843         return failure();
844       rewriter.replaceOp(genericOp, *replacementValues);
845       return success();
846     }
847     return failure();
848   }
849 
850 private:
851   ControlFusionFn controlFoldingReshapes;
852 };
853 
854 /// Pattern to fold a tensor.expand_shape op with its producer generic op
855 /// by expanding the dimensionality of the loop in the producer op.
856 struct FoldReshapeWithGenericOpByExpansion
857     : public OpRewritePattern<tensor::ExpandShapeOp> {
858 
  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
860                                       ControlFusionFn foldReshapes,
861                                       PatternBenefit benefit = 1)
862       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
863         controlFoldingReshapes(std::move(foldReshapes)) {}
864 
  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
866                                 PatternRewriter &rewriter) const override {
867     // Fold only if all constraints of fusing with reshape by expansion are met.
868     GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
869     if (!producer || producer.getNumOutputs() != 1 ||
870         !isFusableWithReshapeByDimExpansion(producer,
871                                             producer.getOutputOperand(0)) ||
872         !controlFoldingReshapes(producer->getResult(0),
873                                 reshapeOp->getOpOperand(0)))
874       return failure();
875     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
876         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
877     if (!replacementValues)
878       return failure();
879     rewriter.replaceOp(reshapeOp, *replacementValues);
880     return success();
881   }
882 
883 private:
884   ControlFusionFn controlFoldingReshapes;
885 };
886 } // namespace
887 
888 //===---------------------------------------------------------------------===//
889 // Methods and patterns to fuse reshape with linalg.generic operations by
890 // contraction of dimensions.
891 //===---------------------------------------------------------------------===//
892 
/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. The projected
/// permutation semantics of the `indexingMap` ensure that all the elements of
/// the returned reassociation are distinct.
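///
/// As an illustrative sketch (assumed map): with indexingMap
/// (d0, d1, d2) -> (d1, d0, d2) and rangeReassociation [0, 1], the returned
/// domain reassociation is [1, 0].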
897 static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
899                        ReassociationIndicesRef rangeReassociation) {
900   assert(indexingMap.isProjectedPermutation() &&
901          "expected projected permutation");
902 
903   ReassociationIndices domainReassociation = llvm::to_vector<4>(
904       llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
905         return indexingMap.getResults()[pos]
906             .cast<AffineDimExpr>()
907             .getPosition();
908       }));
909   // The projected permutation semantics ensures that there is no repetition of
910   // the domain indices.
911   return domainReassociation;
912 }
913 
914 /// For a given `dimSequence`, check if the sequence is conserved in the
915 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// If the sequence is entirely absent from the map, true is returned as well.
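///
/// As an illustrative sketch (assumed map): for
/// indexingMap = (d0, d1, d2, d3) -> (d0, d3, d1, d2), the sequence [1, 2] is
/// preserved (d1 and d2 appear contiguously and in order), whereas the
/// sequence [0, 1] is not (d3 appears between d0 and d1).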
static bool isDimSequencePreserved(AffineMap indexingMap,
918                                    ReassociationIndicesRef dimSequence) {
919   assert(!dimSequence.empty() &&
920          "expected non-empty list for dimension sequence");
921   assert(indexingMap.isProjectedPermutation() &&
922          "expected indexing map to be projected permutation");
923 
924   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
925   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
926 
927   unsigned dimSequenceStart = dimSequence[0];
928   for (const auto &expr : enumerate(indexingMap.getResults())) {
929     unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
931     if (dimInMapStart == dimSequenceStart) {
932       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
933         return false;
934       // 1a. Check if sequence is preserved.
935       for (const auto &dimInSequence : enumerate(dimSequence)) {
936         unsigned dimInMap =
937             indexingMap.getResult(expr.index() + dimInSequence.index())
938                 .cast<AffineDimExpr>()
939                 .getPosition();
940         if (dimInMap != dimInSequence.value())
941           return false;
942       }
943       // Found the sequence. Projected permutation
944       // enforces that all AffineDimExprs in the result are unique, so no
945       // further checks are needed.
946       return true;
947     }
948     // 2. If position in the expr (which is of type AffineDimExpr) is part
949     // of sequence, return false here. This implies the entire sequence does not
950     // exist in the indexing map.
951     if (sequenceElements.count(dimInMapStart))
952       return false;
953   }
954   // 3. No element of sequence found. Return true.
955   return true;
956 }
957 
958 // Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
960 // operation. If all dimensions created by expansion can be collapsed in the
961 // iteration space then the reshape is defunct.
962 //
963 // Example:
964 //
965 // ```mlir
966 // #map = affine_map<(d0, d1) -> (d0, d1)>
967 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
968 // %2 = linalg.init_tensor [..] : tensor<?x4xf32>
969 // %3 = linalg.generic {
970 //     indexing_maps = [#map, #map],
971 //     iterator_types = ["parallel" ,"parallel"]}
972 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
973 // ```
974 //
975 // can be fused by collapsing the dimensions of the iteration space.
976 //
977 // ```mlir
978 // #map = affine_map<(d0) -> (d0)>
979 // %2 = linalg.init_tensor [..] : tensor<?xf32>
980 // %3 = linalg.generic {
981 //     indexing_maps = [#map, #map],
982 //     iterator_types = ["parallel"]}
983 //     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
984 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
985 // ```
986 //
987 // In the following example,
988 //
989 // ```mlir
990 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
991 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
992 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
993 // %2 = linalg.init_tensor [..] : tensor<4x?xf32>
994 // %2 = linalg.generic {
995 //     indexing_maps = [#map0, #map1],
996 //     iterator_types = ["parallel" ,"parallel"]}
997 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
998 // ```
999 //
1000 // the reshape cannot be fused with the generic op by collapsing the op
1001 // dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
1004 static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1006                                  ArrayRef<ReassociationIndices> reassociation) {
1007   // Some basic checks for this fusion to be valid.
1008   if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1009     return {};
1010 
1011   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1012         return map.isProjectedPermutation();
1013       })) {
1014     return {};
1015   }
1016 
1017   // Compute all the loops with the reduction iterator types.
1018   SmallVector<int64_t> reductionDims;
1019   for (const auto &iteratorType : llvm::enumerate(genericOp.iterator_types())) {
1020     if (isReductionIterator(iteratorType.value())) {
1021       reductionDims.push_back(iteratorType.index());
1022     }
1023   }
1024 
1025   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1026   AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
1027   auto iteratorTypes = genericOp.iterator_types().getValue();
1028   SmallVector<ReassociationIndices> iterationSpaceReassociation;
1029   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1030     assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1031 
1032     // Ignore dims that are not folded.
1033     if (foldedRangeDims.size() == 1)
1034       continue;
1035 
1036     ReassociationIndices foldedIterationSpaceDims =
1037         getDomainReassociation(indexingMap, foldedRangeDims);
1038 
1039     // Check that the folded iteration dims do not contain already processed
1040     // dims.
1041     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1042           return processedIterationDims.count(dim);
1043         }))
1044       continue;
1045 
1046     // Check that all folded iterator types are all parallel or all reductions.
1047     Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
1048     if (!isParallelIterator(startIteratorType) &&
1049         !isReductionIterator(startIteratorType))
1050       continue;
1051     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1052           return iteratorTypes[dim] != startIteratorType;
1053         }))
1054       continue;
1055 
1056     // If the folded dimensions correspond to a "reduction" iterator type,
1057     // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary for reductions that are associative and commutative, but
    // a stricter definition of reduction is used for now.
1060     if (isReductionIterator(startIteratorType)) {
1061       bool isContiguous = false;
1062       for (const auto &startDim : llvm::enumerate(reductionDims)) {
1063         // Move window in `reductionDims` to start of the folded iteration dims.
1064         if (startDim.value() != foldedIterationSpaceDims[0])
1065           continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
1068         if (startDim.index() + foldedIterationSpaceDims.size() >
1069             reductionDims.size())
1070           break;
1071         // Check that the contiguity is maintained.
1072         isContiguous = true;
1073         for (const auto &foldedDim :
1074              llvm::enumerate(foldedIterationSpaceDims)) {
1075           if (reductionDims[foldedDim.index() + startDim.index()] !=
1076               foldedDim.value()) {
1077             isContiguous = false;
1078             break;
1079           }
1080         }
1081         break;
1082       }
1083       if (!isContiguous)
1084         continue;
1085     }
1086 
1087     // Check that the sequence is preserved in all indexing maps.
1088     if (llvm::any_of(genericOp.getIndexingMapsArray(),
1089                      [&](AffineMap indexingMap) {
1090                        return !isDimSequencePreserved(indexingMap,
1091                                                       foldedIterationSpaceDims);
1092                      }))
1093       continue;
1094 
1095     processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1096                                   foldedIterationSpaceDims.end());
1097     iterationSpaceReassociation.emplace_back(
1098         std::move(foldedIterationSpaceDims));
1099   }
1100 
1101   return iterationSpaceReassociation;
1102 }
1103 
1104 /// Helper class to carry state while collapsing the `linalg.generic` op.
1105 namespace {
1106 class CollapsingInfo {
1107 public:
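  /// Initialize the collapsing information from the list of folded iteration
  /// dimensions. As an illustrative sketch (assumed dimensions): with
  /// origNumLoops = 4 and foldedIterationDims = [[1, 2]], the collapsed-to-
  /// original mapping becomes [[0], [1, 2], [3]] and the original-to-collapsed
  /// mapping becomes {0: (0, 0), 1: (1, 0), 2: (1, 1), 3: (2, 0)}.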
  LogicalResult initialize(unsigned origNumLoops,
1109                            ArrayRef<ReassociationIndices> foldedIterationDims) {
1110     llvm::SmallDenseSet<int64_t, 4> processedDims;
1111     // Find all the dims that are folded.
1112     for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1113       if (foldedIterationDim.empty())
1114         continue;
1115       // If the folded dims contain dims already folded, that's illegal
1116       // specification. Repetition within a list is also illegal.
1117       for (auto dim : foldedIterationDim) {
1118         if (dim >= origNumLoops)
1119           return failure();
1120         if (processedDims.count(dim))
1121           return failure();
1122         processedDims.insert(dim);
1123       }
1124       collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1125                                                    foldedIterationDim.end());
1126     }
1127     if (processedDims.size() > origNumLoops)
1128       return failure();
1129 
1130     // Add all the preserved dims of the original op as single
1131     // elements to `collapsedOpToOrigOpIterationDim`.
1132     for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1133       if (processedDims.count(dim))
1134         continue;
1135       collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1136     }
1137 
1138     llvm::sort(collapsedOpToOrigOpIterationDim,
1139                [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1140                  return lhs[0] < rhs[0];
1141                });
1142     origOpToCollapsedOpIterationDim.resize(origNumLoops);
1143     for (const auto &foldedDims :
1144          llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1145       for (const auto &dim : enumerate(foldedDims.value()))
1146         origOpToCollapsedOpIterationDim[dim.value()] =
1147             std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1148     }
1149     return success();
1150   }
1151 
1152   /// Return mapping from collapsed loop domain to original loop domain.
getCollapsedOpToOrigOpMapping() const1153   ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1154     return collapsedOpToOrigOpIterationDim;
1155   }
1156 
1157   /// Return mapping from original loop domain to collapsed loop domain. The
1158   /// mapping is a pair. First value is the dimension in the collapsed loop that
1159   /// the original loop is mapped to. Second is the relative position in folded
1160   /// list of this domain. For example if the original loop domain is 3D, and
1161   /// the collapsed loop domain is folding all of it, i.e.
1162   ///
1163   /// ```
1164   /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
1165   /// ```
1166   ///
1167   /// then
1168   ///
1169   /// ```
1170   ///  origOpToCollapsedOpMapping[0] = {0, 0};
1171   ///  origOpToCollapsedOpMapping[1] = {0, 1};
1172   ///  origOpToCollapsedOpMapping[2] = {0, 2};
1173   ///  origOpToCollapsedOpMapping[3] = {1, 0};
1174   ///  origOpToCollapsedOpMapping[4] = {1, 1};
1175   /// ```
1176   ///
getOrigOpToCollapsedOpMapping() const1177   ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1178     return origOpToCollapsedOpIterationDim;
1179   }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from iteration domain index in the original op to the iteration domain
  /// index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
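/// For example (hypothetical values), with
/// `collapsedOpToOrigOpMapping = [[0, 1], [2]]` and original iterator types
/// `["parallel", "parallel", "reduction"]`, the collapsed iterator types are
/// `["parallel", "reduction"]`.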
static SmallVector<StringRef>
getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<StringRef> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. The pre-condition
    // checks are expected to have verified that the iterator types of all
    // folded dimensions are the same.
    collapsedIteratorTypes.push_back(
        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
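/// For example (hypothetical values), if the original indexing map is
/// `affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>` and the iteration dims are
/// collapsed as `[[0, 1], [2], [3]]`, the collapsed op has 3 loops and the
/// resulting indexing map is `affine_map<(d0, d1, d2) -> (d0, d2)>`.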
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    // If the dim is not the first of the collapsed dims, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The next n dims are guaranteed to be collapsed into it, so just use the
    // iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}

/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
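/// Continuing the hypothetical example above, for the indexing map
/// `affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>` and collapsed iteration dims
/// `[[0, 1], [2], [3]]`, the operand reassociation is `[[0, 1], [2]]`: the
/// first two operand dimensions fold together and the last one is untouched.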
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      // This is the start of a collapsed group of iteration dimensions that is
      // guaranteed to be preserved in the indexing map. The number of folded
      // dims is obtained from the collapsed op to original op mapping.
      unsigned numFoldedDims =
          collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
              .size();
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
      counter += numFoldedDims;
    }
  }
  return operandReassociation;
}

/// Get the new value to use for a given `OpOperand` in the collapsed operation.
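/// If the operand dimensions actually need to be folded, this builds a
/// `tensor.collapse_shape` with the operand reassociation, e.g. (illustrative
/// types only):
///
/// ```mlir
///   %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
///       : tensor<?x4x?xf32> into tensor<?x?xf32>
/// ```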
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, then nothing to do for this
  // operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
      loc, operand, operandReassociation);
  return reshapeOp.getResult();
}

/// Rewrite the `linalg.index` operations in the original generic op to their
/// values in the collapsed operation.
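/// As a hypothetical worked example, for folded dims with sizes d1 = 4 and
/// d2 = 8, an original index triple (i0, i1, i2) = (0, 3, 6) corresponds to
/// i_{folded} = (0 * 4 + 3) * 8 + 6 = 30, which is recovered below as
/// i2 = 30 % 8 = 6, i1 = (30 / 8) % 4 = 3, and i0 = 30 / (4 * 8) = 0.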
void generateCollapsedIndexingRegion(Location loc, Block *block,
                                     const CollapsingInfo &collapsingInfo,
                                     ValueRange loopRange,
                                     PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list, resolve the original induction variable
  // values in terms of the folded dimension induction variable.
  //   i_{folded} = (i0 * d1 + i1) * d2 + i2
  // can be inverted to
  //   i2 = i_{folded} % d2
  //   i1 = (i_{folded} / d2) % d1
  //   i0 = i_{folded} / (d1 * d2)
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto &foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      indexReplacementVals[dim] =
          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
      newIndexVal =
          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.dim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}

/// Implementation of fusion with a reshape operation by collapsing dimensions.
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
  // Bail on trivial no-op cases.
  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                       foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        genericOp, "illegal to collapse specified dimensions");
  }

  // Get the iterator types for the collapsed operation.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc = genericOp->getLoc();

  // Get the input operands.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}

namespace {

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing dimensions of the iteration space.
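/// For example (a rough sketch, eliding the exact attribute syntax):
///
/// ```mlir
///   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
///       : tensor<?x?xf32> into tensor<?x4x?xf32>
///   %1 = linalg.generic {indexing_maps = [#identity_3d, #identity_3d],
///                        iterator_types = ["parallel", "parallel", "parallel"]}
///       ins(%0 : tensor<?x4x?xf32>) outs(%init : tensor<?x4x?xf32>) {...}
/// ```
///
/// can be rewritten into a 2-loop `linalg.generic` that consumes `%arg0`
/// directly, with a `tensor.expand_shape` inserted on the result to recover
/// the original type.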
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
        continue;
      }

      Optional<SmallVector<Value>> replacements =
          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
                                         opOperand, rewriter);
      if (!replacements) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, *replacements);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
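/// For example (a rough sketch), a producer such as
///
/// ```mlir
///   %cst = arith.constant dense<4.0> : tensor<8x16xf32>
/// ```
///
/// used as an input of a `linalg.generic` is replaced by a scalar
/// `arith.constant 4.0 : f32` that is used directly inside the fused region,
/// and the corresponding tensor operand and indexing map are dropped.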
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation, with the
      // value at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
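/// For example (a rough sketch), an `outs` operand `%arg1 : tensor<?x?xf32>`
/// whose value is never read by the payload is replaced by
///
/// ```mlir
///   %d0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
///   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
///   %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
/// ```
///
/// which breaks the false dependency on `%arg1` and unblocks fusion.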
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparse compiler.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

/// Fold linalg.fill into linalg.generic.
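/// For example (a rough sketch), if an input of the generic op is produced by
///
/// ```mlir
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>)
///       -> tensor<?x?xf32>
/// ```
///
/// then uses of the corresponding block argument in the payload are replaced
/// directly with `%cst`, making the fill tensor dead if it has no other uses.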
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.region().front();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      payload.getArgument(opOperand->getOperandNumber())
          .replaceAllUsesWith(fillOp.value());
    }
    return success(fillFound);
  }
};
} // namespace

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
}

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
// favor of test passes that check the functionality of each of the patterns
// added here individually.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    // Add folding with reshape by expansion patterns.
    ControlFusionFn defaultControlFn = [](const OpResult &producer,
                                          const OpOperand &consumer) {
      return producer.hasOneUse();
    };

    // Add elementwise op fusion patterns.
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);

    // General canonicalization patterns.
    AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
        patterns);

    // Add constant folding patterns.
    populateConstantFoldLinalgOperations(patterns, defaultControlFn);

    // Use TopDownTraversal for compile time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}