//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//
12 #include <utility>
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28 using namespace mlir;
29 using namespace mlir::linalg;
30
31 //===---------------------------------------------------------------------===//
32 // Methods and patterns that fuse elementwise `linalg.generic` operations.
33 //===---------------------------------------------------------------------===//
34
/// Return the indexing map for an operand of the `producer` to use in the
/// fused operation, given the indexing map of the result of the producer in
/// the consumer.
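///
/// For example (an illustrative set of maps, assuming a 2-d producer and a
/// 3-d consumer):
///   producer operand map          : (d0, d1) -> (d1)
///   producer result map           : (d0, d1) -> (d1, d0)
///   consumer map of fused operand : (d0, d1, d2) -> (d0, d2)
/// The inverse of the producer result map is (d0, d1) -> (d1, d0); composing
/// the operand map with it gives (d0, d1) -> (d0), and composing that with
/// the consumer map yields (d0, d1, d2) -> (d0), i.e. the producer operand
/// expressed in terms of the fused (consumer) loops.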
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
41 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
42 // from consumer loop -> consumer arg tensor index/producer result tensor
43 // index. The fused loop is same as the consumer loop. For each producer arg
44 // the indexing map to be computed is a map from consumer loop -> producer
45 // arg tensor index.
46 // producerResultIndexMap is a map from producer loop -> tensor index.
47 // Compute the inverse to get map from tensor index -> producer loop.
48 // The inverse is a map from producer result tensor index -> producer loop.
49 AffineMap invProducerResultIndexMap =
50 inversePermutation(producerResultIndexMap);
51 assert(invProducerResultIndexMap &&
52 "expected producer result indexing map to be invertible");
53
54 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
55 // argMap is a map from producer loop -> producer arg tensor index.
56 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
57
58 // Compose argMap with invProducerResultIndexMap to get a map from
59 // producer result tensor index -> producer arg tensor index.
60 AffineMap t1 = argMap.compose(invProducerResultIndexMap);
61
62 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
63 // consumer loop/ fused loop -> producer arg tensor index.
64 return t1.compose(fusedConsumerArgIndexMap);
65 }
66
/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
70 // Producer and consumer must have tensor semantics.
71 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
72 return false;
73
  // Verify that
  // - the producer has all "parallel" iterator types.
76 if (producer.getNumParallelLoops() != producer.getNumLoops())
77 return false;
78
79 // Only allow fusing the producer of an input operand for now.
80 // TODO: allow fusing the producer of an output operand.
81 if (!consumer.isInputTensor(consumerOpOperand))
82 return false;
83
84 // Get the consumer index map. The number of results of the consumer index
85 // map must match the number of loops of the producer.
86 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87 if (consumerIndexMap.getNumResults() != producer.getNumLoops())
88 return false;
89
90 // Currently support only operations with single result.
91 if (producer.getNumOutputs() != 1)
92 return false;
93
94 // Finally the index_map for the result must be invertible. For now just
95 // verify it is a permutation.
96 AffineMap producerResultIndexMap =
97 producer.getTiedIndexingMap(producer.getOutputOperand(0));
98 if (!producerResultIndexMap.isPermutation())
99 return false;
100
  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
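  //
  // For example (illustrative): if the consumer reduces d1 with
  //   fused operand map : (d0, d1) -> (d0, d1)
  //   output map        : (d0, d1) -> (d0)
  // and, after fusion, every producer operand is indexed only by d0 (e.g. a
  // value broadcast along d1), then no remaining operand carries the extent
  // of d1 and the fusion is rejected.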
105 if ((consumer.getNumReductionLoops())) {
106 BitVector coveredDims(consumer.getNumLoops(), false);
107
108 auto addToCoveredDims = [&](AffineMap map) {
109 for (auto result : map.getResults())
110 if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
111 coveredDims[dimExpr.getPosition()] = true;
112 };
113
114 for (auto pair :
115 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
116 Value operand = std::get<0>(pair);
117 if (operand == consumerOpOperand->get())
118 continue;
119 AffineMap operandMap = std::get<1>(pair);
120 addToCoveredDims(operandMap);
121 }
122
123 for (OpOperand *operand : producer.getInputOperands()) {
124 AffineMap newIndexingMap =
125 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
126 operand, producerResultIndexMap, consumerIndexMap);
127 addToCoveredDims(newIndexingMap);
128 }
129 if (!coveredDims.all())
130 return false;
131 }
132
133 return true;
134 }
135
136 /// Generate the region of the fused tensor operation. The region of the fused
137 /// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
143 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
144 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
145 // Build the region of the fused op.
146 Block &producerBlock = producer->getRegion(0).front();
147 Block &consumerBlock = consumer->getRegion(0).front();
148 Block *fusedBlock = new Block();
149 fusedOp.region().push_back(fusedBlock);
150 BlockAndValueMapping mapper;
151 OpBuilder::InsertionGuard guard(rewriter);
152 rewriter.setInsertionPointToStart(fusedBlock);
153
154 // 2. Add an index operation for every fused loop dimension and use the
155 // `consumerToProducerLoopsMap` to map the producer indices.
156 if (producer.hasIndexSemantics()) {
157 // Add an index operation for every fused loop dimension.
158 unsigned numFusedOpLoops =
159 std::max(producer.getNumLoops(), consumer.getNumLoops());
160 SmallVector<Value> fusedIndices;
161 fusedIndices.reserve(numFusedOpLoops);
162 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163 std::back_inserter(fusedIndices), [&](uint64_t dim) {
164 return rewriter.create<IndexOp>(producer.getLoc(), dim);
165 });
166 for (IndexOp indexOp :
167 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168 Value newIndex = rewriter.create<mlir::AffineApplyOp>(
169 producer.getLoc(),
170 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
171 mapper.map(indexOp.getResult(), newIndex);
172 }
173 }
174 // TODO: allow fusing the producer of an output operand.
175 assert(consumer.isInputTensor(consumerOpOperand) &&
176 "expected producer of input operand");
177 // 3. Consumer input operands up to consumerIdx (exclusive).
178 for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
179 consumerOpOperand->getOperandNumber())) // input assumption.
180 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
181
182 // Replacing consumerIdx requires getting the cloned, yielded, value from
183 // the (cloned) producer block. This happens in step 9.
184
185 // 4. Splice in producer's input operands.
186 for (BlockArgument bbArg :
187 producerBlock.getArguments().take_front(producer.getNumInputs()))
188 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
189
190 // 4.b. Producer output operand/map that is fused needs to be mapped to the
191 // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
192 assert(producer->getNumResults() == 1 && "expected single result producer");
193 if (producer.isInitTensor(producer.getOutputOperand(0))) {
194 BlockArgument bbArg = producerBlock.getArguments()
195 .drop_front(producer.getNumInputs())
196 // TODO: bbArg index of
197 .front();
198 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
199 }
200 // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
201 for (BlockArgument bbArg :
202 consumerBlock.getArguments()
203 .take_front(consumer.getNumInputs())
204 .drop_front(consumerOpOperand->getOperandNumber() + 1))
205 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
206 // 6. All of consumer's output operands.
207 for (BlockArgument bbArg :
208 consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
209 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
210 // 7. All of producer's output operands except the one fused.
211 // TODO: allow fusion of multi-result producers.
212 assert(producer->getNumResults() == 1 && "expected single result producer");
213
214 // 8. Clone all producer operations except for the yield and index operations
215 // to the fused operation.
216 for (auto &op : producerBlock.without_terminator()) {
217 if (!isa<IndexOp>(op))
218 rewriter.clone(op, mapper);
219 }
220 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
221 // forward the yield operand.
222 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
223 // TODO: allow fusion of multi-result producers.
224 assert(producer->getNumResults() == 1 && "expected single result producer");
225 unsigned producerResultNumber = 0;
226 Value replacement =
227 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
228 // Sanity checks, if replacement is not already in the mapper then it must be
229 // produced outside.
230 if (replacement == yieldOp.getOperand(producerResultNumber)) {
231 if (auto bb = replacement.dyn_cast<BlockArgument>())
232 assert(bb.getOwner() != &producerBlock &&
233 "yielded block argument must have been mapped");
234 else
235 assert(!producer->isAncestor(replacement.getDefiningOp()) &&
236 "yielded value must have been mapped");
237 }
238 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
239 replacement);
240 // 10. Clone operations from the consumer to the fused op.
241 for (auto &op : consumerBlock.getOperations())
242 rewriter.clone(op, mapper);
243
244 // Sanity checks.
245 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
246 "Ill-formed GenericOp region");
247 }
248
static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlFusionFn &controlFn,
                       PatternRewriter &rewriter) {
253 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
254 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
255 !controlFn(producer->getResult(0), *consumerOpOperand))
256 return llvm::None;
257
258 // TODO: allow fusing the producer of an output operand.
259 assert(consumer.isInputTensor(consumerOpOperand) &&
260 "expected producer of input operand");
261
262 // Compute the fused operands list and indexing maps.
263 SmallVector<Value> fusedOperands;
264 SmallVector<AffineMap> fusedIndexMaps;
265 fusedOperands.reserve(producer->getNumOperands() +
266 consumer->getNumOperands());
267 fusedIndexMaps.reserve(producer->getNumOperands() +
268 consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
270 // 3. Consumer input operands/maps up to consumerIdx (exclusive).
271 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
272 SmallVector<OpOperand *>::iterator it =
273 llvm::find(consumerInputs, consumerOpOperand);
274 assert(it != consumerInputs.end() && "expected to find the consumer operand");
275 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276 fusedOperands.push_back(opOperand->get());
277 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
278 }
279 // 4. Splice in producer's input operands/maps.
280 assert(producer->getNumResults() == 1 && "expected single result producer");
281 AffineMap producerResultIndexMap =
282 producer.getTiedIndexingMap(producer.getOutputOperand(0));
283 for (OpOperand *opOperand : producer.getInputOperands()) {
284 fusedOperands.push_back(opOperand->get());
285 // Compute indexing maps for the producer args in the fused operation.
286 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
287 opOperand, producerResultIndexMap,
288 consumer.getTiedIndexingMap(consumerOpOperand));
289 fusedIndexMaps.push_back(map);
290 }
291 // 4.b. Producer output operand/map that is fused needs to be passed if it is
292 // an "initTensor" (i.e. its value is actually read).
293 assert(producer->getNumResults() == 1 && "expected single result producer");
294 if (producer.isInitTensor(producer.getOutputOperand(0))) {
295 fusedOperands.push_back(producer.getOutputOperand(0)->get());
296 // Compute indexing maps for the producer args in the fused operation.
297 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
298 producer.getOutputOperand(0), producerResultIndexMap,
299 consumer.getTiedIndexingMap(consumerOpOperand));
300 fusedIndexMaps.push_back(map);
301 }
302 // 5. Remaining consumer's input operands/maps (drop past index
303 // `consumerIdx`).
304 for (OpOperand *opOperand :
305 llvm::make_range(std::next(it), consumerInputs.end())) {
306 fusedOperands.push_back(opOperand->get());
307 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
308 }
309 // 6. All of consumer's output operands (skip operands: added by the builder).
310 for (OpOperand *opOperand : consumer.getOutputOperands())
311 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
312 // 7. All of producer's output operands/maps except the one fused.
313 // TODO: allow fusion of multi-result producers.
314 assert(producer->getNumResults() == 1 && "expected single result producer");
315
316 // Generate the fused op.
317 SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
318 auto fusedOp = rewriter.create<GenericOp>(
319 consumer.getLoc(), consumer->getResultTypes(),
320 /*inputs=*/fusedOperands,
321 // TODO: handle outputs.
322 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
323 consumer.iterator_types(),
324 /*doc=*/nullptr,
325 /*library_call=*/nullptr);
326 if (!fusedOp.getShapesToLoopsMap()) {
327 // Fused op has invalid indexing maps. Typically this means something is off
328 // in the input, but going ahead here would result in verification errors.
329 // So cleanup and abort.
330 rewriter.eraseOp(fusedOp);
331 return llvm::None;
332 }
333
334 // Construct an AffineMap from consumer loops to producer loops.
335 // consumer loop -> tensor index
336 AffineMap consumerResultIndexMap =
337 consumer.getTiedIndexingMap(consumerOpOperand);
338 // tensor index -> producer loop
339 AffineMap invProducerResultIndexMap =
340 inversePermutation(producerResultIndexMap);
341 assert(invProducerResultIndexMap &&
342 "expected producer result indexig map to be invertible");
343 // consumer loop -> producer loop
344 AffineMap consumerToProducerLoopsMap =
345 invProducerResultIndexMap.compose(consumerResultIndexMap);
346
347 generateFusedElementwiseOpRegion(rewriter, fusedOp,
348 consumerToProducerLoopsMap,
349 consumerOpOperand, consumer.getNumLoops());
350 return SmallVector<Value>(fusedOp->getResults());
351 }
352
static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer, const ControlFusionFn &controlFn) {
356 if (producer->getNumResults() != 1)
357 return llvm::None;
358
359 return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
360 rewriter);
361 }
362
363 namespace {
364 /// Patterns to fuse a generic op, with the producer of its operands.
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
366 public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
369 : OpRewritePattern<GenericOp>(context, benefit),
370 controlFn(std::move(fun)) {}
371
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
374 // Find the first operand that is defined by another generic op on tensors.
375 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
376 auto producer =
377 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378 if (!producer || !producer.hasTensorSemantics())
379 continue;
380 Optional<SmallVector<Value>> fusedOpResults =
381 fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
382 if (fusedOpResults) {
383 rewriter.replaceOp(genericOp, *fusedOpResults);
384 return success();
385 }
386 }
387 return failure();
388 }
389
390 private:
391 ControlFusionFn controlFn;
392 };
393 } // namespace
394
395 //===---------------------------------------------------------------------===//
396 // Methods and patterns that fuse reshape ops with elementwise operations by
397 // expanding the dimensionality of the elementwise operations.
398 //===---------------------------------------------------------------------===//
399
400 /// Conditions for folding a generic operation with a reshape op by expanding
401 /// the iteration space dimensionality for tensor operations. These are
402 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
403 /// following fusion pattern.
404 ///
405 /// Consider
406 ///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
413 ///
414 /// The reshape can be folded into the `genericOp` if its loop dimensionality
415 /// is increased to match the result (operand) of the tensor.expand_shape.
416 /// The indexing_map of the fused tensor in the `genericOp` and the
417 /// reassociation map helps compute the indexing maps of the modified op.
418 /// For the above example, based on the reassociation map it
419 /// can be concluded that
420 ///
421 /// - The loop used to access the first dimension of the fused tensor is split
422 /// into two.
423 /// - The loop used to access the second dimension of the fused tensor is kept
424 /// as is.
425 /// - The loop used to access the third dimension of the fused tensor is split
426 /// into three.
427 ///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
434 ///
435 /// substituting this, the generic op can be rewritten as
436 ///
437 /// %d = linalg.generic ins(%0, %1 : )
438 /// indexing_maps =
439 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
440 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
441 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
442 ///
/// Since the operands to the linalg generic are now accessed through
/// higher-dimensional indexing maps, reshapes can be introduced to make them
/// consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
455 // Is fusable only if:
456 // - All the indexing maps for operands and results are projected
457 // permutations.
458 // - The fused tensor is not a scalar.
459 // - All the loops are parallel loops.
460 return genericOp.hasTensorSemantics() &&
461 llvm::all_of(genericOp.indexing_maps().getValue(),
462 [](Attribute attr) {
463 return attr.cast<AffineMapAttr>()
464 .getValue()
465 .isProjectedPermutation();
466 }) &&
467 genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
468 llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
469 return attr.cast<StringAttr>().getValue() ==
470 getParallelIteratorTypeName();
471 });
472 }
473
474 namespace {
475 /// Information needed to expand a generic operation to fold the reshape with
476 /// it.
477 class ExpansionInfo {
478 public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
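  //
  // For example (illustrative): if the fused operand has indexing map
  // (d0, d1) -> (d1, d0), the reshape reassociation is [[0, 1], [2]] and the
  // expanded shape is [4, 8, 16], then d1 expands into two dimensions of
  // sizes [4, 8], d0 stays a single dimension of size [16], and the
  // reassociation from the original to the expanded iteration space is
  // [[0], [1, 2]].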
483 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
484 ArrayRef<AffineMap> reassociationMaps,
485 ArrayRef<int64_t> expandedShape,
486 ArrayRef<int64_t> collapsedShape,
487 PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
497
498 private:
499 /// Reassociation from the dimensions in the original operation to the
500 /// dimension of the expanded operation.
501 SmallVector<ReassociationIndices> reassociation;
502 /// Mapping from extent of loops in the original operation, to the extent of
503 /// loops in the expanded operation.
504 SmallVector<SmallVector<int64_t>> expandedShapeMap;
505 /// Extent of the loop in the original operation.
506 SmallVector<int64_t> originalLoopExtent;
507 unsigned expandedOpNumDims;
508 };
509 } // namespace
510
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
517 if (reassociationMaps.empty())
518 return failure();
519 AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
520
521 SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
522 originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
523
524 reassociation.clear();
525 expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
528 SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
529 expandedShapeMap.resize(fusedIndexMap.getNumDims());
530 for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
531 unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
532 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
533 numExpandedDims[pos] = foldedDims.getNumResults();
534 ArrayRef<int64_t> shape =
535 expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
536 expandedShapeMap[pos].assign(shape.begin(), shape.end());
537 }
538 // The remaining dimensions remain the same.
539 for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
540 if (expandedShapeMap[i].empty())
541 expandedShapeMap[i] = {originalLoopExtent[i]};
542
543 // Compute reassociation map from the original op to the expanded op.
544 unsigned sum = 0;
545 reassociation.reserve(fusedIndexMap.getNumDims());
546 for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
547 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
548 reassociation.emplace_back(seq.begin(), seq.end());
549 sum += numFoldedDim.value();
550 }
551 expandedOpNumDims = sum;
552 return success();
553 }
554
/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation needs to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
566 if (!genericOp.hasIndexSemantics())
567 return success();
568 for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
569 ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
570 if (expandedShape.size() == 1)
571 continue;
572 for (int64_t shape : expandedShape.drop_front()) {
573 if (ShapedType::isDynamic(shape)) {
574 return rewriter.notifyMatchFailure(
575 genericOp, "cannot expand due to index semantics and dynamic dims");
576 }
577 }
578 }
579 return success();
580 }
581
/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
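///
/// For example (illustrative): with an original map (d0, d1) -> (d1, d0) and
/// an expansion d0 -> (e0, e1), d1 -> (e2), the expanded map is
/// (e0, e1, e2) -> (e2, e0, e1).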
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
587 SmallVector<AffineExpr> newExprs;
588 for (AffineExpr expr : indexingMap.getResults()) {
589 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
590 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
591 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
592 return builder.getAffineDimExpr(static_cast<unsigned>(v));
593 }));
594 newExprs.append(expandedExprs.begin(), expandedExprs.end());
595 }
596 return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
597 indexingMap.getNumSymbols(), newExprs,
598 builder.getContext());
599 }
600
/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
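///
/// For example (illustrative): a `tensor<10x20xf32>` operand with indexing map
/// (d0, d1) -> (d0, d1), where d0 expands into sizes [2, 5] and d1 stays [20],
/// becomes a `tensor<2x5x20xf32>` in the expanded op.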
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
606 SmallVector<int64_t> expandedShape;
607 for (AffineExpr expr : indexingMap.getResults()) {
608 unsigned dim = expr.cast<AffineDimExpr>().getPosition();
609 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
610 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
611 }
612 return RankedTensorType::get(expandedShape, originalType.getElementType());
613 }
614
/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
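///
/// For example (continuing the illustration above): for an operand with map
/// (d0, d1) -> (d0, d1) where d0 expands into two dimensions and d1 into one,
/// the reassociation is [[0, 1], [2]].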
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
624 SmallVector<ReassociationIndices> reassociation;
625 unsigned numReshapeDims = 0;
626 for (AffineExpr expr : indexingMap.getResults()) {
627 unsigned dim = expr.cast<AffineDimExpr>().getPosition();
628 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
629 SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
630 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
631 reassociation.emplace_back(std::move(indices));
632 numReshapeDims += numExpandedDims;
633 }
634 return reassociation;
635 }
636
/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
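///
/// For example (illustrative): if dimension `d` of the original op is expanded
/// into (e0, e1, e2) with static sizes (s0, s1, s2), then
///   linalg.index d  ==>  (idx(e0) * s1 + idx(e1)) * s2 + idx(e2)
/// where idx(e) denotes the `linalg.index` of expanded dimension `e`.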
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
645 // Replace the original indices by the linearization of the expanded indices.
646 for (IndexOp indexOp :
647 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
648 ArrayRef<int64_t> expandedDims =
649 expansionInfo.getExpandedDims(indexOp.dim());
650 assert(!expandedDims.empty() && "expected valid expansion info");
651
652 // Skip index operations that are not affected by the expansion.
653 if (expandedDims.size() == 1 &&
654 expandedDims.front() == (int64_t)indexOp.dim())
655 continue;
656
657 // Linearize the expanded indices of the original index dimension.
658 OpBuilder::InsertionGuard guard(rewriter);
659 rewriter.setInsertionPointAfter(indexOp);
660 ArrayRef<int64_t> expandedDimsShape =
661 expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
662 SmallVector<Value> expandedIndices;
663 expandedIndices.reserve(expandedDims.size() - 1);
664 llvm::transform(
665 expandedDims.drop_front(), std::back_inserter(expandedIndices),
666 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
667 Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
668 for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
669 assert(!ShapedType::isDynamic(std::get<0>(it)));
670 AffineExpr idx, acc;
671 bindDims(rewriter.getContext(), idx, acc);
672 newIndex = rewriter.create<AffineApplyOp>(
673 indexOp.getLoc(), idx + acc * std::get<0>(it),
674 ValueRange{std::get<1>(it), newIndex});
675 }
676 rewriter.replaceOp(indexOp, newIndex);
677 }
678 }
679
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
687 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
688 "preconditions for fuse operation failed");
689 // Check if reshape is expanding or collapsing.
690 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
691 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
692 bool isExpanding = (expandingReshapeOp != nullptr);
693 RankedTensorType expandedType = isExpanding
694 ? expandingReshapeOp.getResultType()
695 : collapsingReshapeOp.getSrcType();
696 RankedTensorType collapsedType = isExpanding
697 ? expandingReshapeOp.getSrcType()
698 : collapsingReshapeOp.getResultType();
699
700 ExpansionInfo expansionInfo;
701 if (failed(expansionInfo.compute(
702 genericOp, fusableOpOperand,
703 isExpanding ? expandingReshapeOp.getReassociationMaps()
704 : collapsingReshapeOp.getReassociationMaps(),
705 expandedType.getShape(), collapsedType.getShape(), rewriter)))
706 return llvm::None;
707
708 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
709 return llvm::None;
710
711 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
712 llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
713 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
714 }));
715
716 SmallVector<Value> expandedOpOperands;
717 expandedOpOperands.reserve(genericOp.getNumInputs());
718 for (OpOperand *opOperand : genericOp.getInputOperands()) {
719 if (opOperand == fusableOpOperand) {
720 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
721 : collapsingReshapeOp.getSrc());
722 continue;
723 }
724 if (genericOp.isInputTensor(opOperand)) {
725 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
726 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
727 RankedTensorType expandedOperandType =
728 getExpandedType(opOperandType, indexingMap, expansionInfo);
729 if (expandedOperandType != opOperand->get().getType()) {
730 // Reshape the operand to get the right type.
731 SmallVector<ReassociationIndices> reassociation =
732 getReassociationForExpansion(indexingMap, expansionInfo);
733 if (failed(reshapeLikeShapesAreCompatible(
734 [&](const Twine &msg) {
735 return rewriter.notifyMatchFailure(genericOp, msg);
736 },
737 opOperandType.getShape(), expandedOperandType.getShape(),
738 reassociation,
739 /*isExpandingReshape=*/true)))
740 return llvm::None;
741 expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
742 genericOp.getLoc(), expandedOperandType, opOperand->get(),
743 reassociation));
744 continue;
745 }
746 }
747 expandedOpOperands.push_back(opOperand->get());
748 }
749
750 Location loc = genericOp.getLoc();
751 SmallVector<Value> outputs;
752 for (OpOperand *opOperand : genericOp.getOutputOperands()) {
753 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
754 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
755 RankedTensorType expandedOutputType =
756 getExpandedType(opOperandType, indexingMap, expansionInfo);
757 if (expandedOutputType != opOperand->get().getType()) {
758 SmallVector<ReassociationIndices> reassociation =
759 getReassociationForExpansion(indexingMap, expansionInfo);
760 if (failed(reshapeLikeShapesAreCompatible(
761 [&](const Twine &msg) {
762 return rewriter.notifyMatchFailure(genericOp, msg);
763 },
764 opOperandType.getShape(), expandedOutputType.getShape(),
765 reassociation,
766 /*isExpandingReshape=*/true)))
767 return llvm::None;
768 outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
769 genericOp.getLoc(), expandedOutputType, opOperand->get(),
770 reassociation));
771 }
772 }
773
774 // The iterator types of the expanded op are all parallel.
775 SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
776 getParallelIteratorTypeName());
777
778 TypeRange resultTypes = ValueRange(outputs).getTypes();
779 auto fusedOp =
780 rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
781 /*inputs=*/expandedOpOperands, outputs,
782 expandedOpIndexingMaps, iteratorTypes);
783 Region &fusedRegion = fusedOp->getRegion(0);
784 Region &originalRegion = genericOp->getRegion(0);
785 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
786
787 // Update the index accesses after the expansion.
788 updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
789
790 // Reshape the result values to their original shape if this is a collapsing
791 // reshape folded into its consumer.
792 SmallVector<Value> resultVals;
793 for (OpResult opResult : genericOp->getOpResults()) {
794 int64_t resultNumber = opResult.getResultNumber();
795 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
796 SmallVector<ReassociationIndices> reassociation =
797 getReassociationForExpansion(
798 genericOp.getTiedIndexingMap(
799 genericOp.getOutputOperand(resultNumber)),
800 expansionInfo);
801 resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
802 genericOp.getLoc(), opResult.getType(),
803 fusedOp->getResult(resultNumber), reassociation));
804 } else {
805 resultVals.push_back(fusedOp->getResult(resultNumber));
806 }
807 }
808 // Assuming a single result.
809 return resultVals;
810 }
811
812 namespace {
813
814 /// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
815 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
816 /// in the consumer is expanded.
817 class FoldWithProducerReshapeOpByExpansion
818 : public OpRewritePattern<GenericOp> {
819 public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
823 : OpRewritePattern<GenericOp>(context, benefit),
824 controlFoldingReshapes(std::move(foldReshapes)) {}
825
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
828 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
829 tensor::CollapseShapeOp reshapeOp =
830 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
831 if (!reshapeOp)
832 continue;
833 // Fold only if
834 // - The tensor reshape op is folding.
835 // - All constraints of fusing with reshape by expansion are met.
836 if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
837 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
838 continue;
839
840 Optional<SmallVector<Value>> replacementValues =
841 fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
842 if (!replacementValues)
843 return failure();
844 rewriter.replaceOp(genericOp, *replacementValues);
845 return success();
846 }
847 return failure();
848 }
849
850 private:
851 ControlFusionFn controlFoldingReshapes;
852 };
853
854 /// Pattern to fold a tensor.expand_shape op with its producer generic op
855 /// by expanding the dimensionality of the loop in the producer op.
856 struct FoldReshapeWithGenericOpByExpansion
857 : public OpRewritePattern<tensor::ExpandShapeOp> {
858
  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
862 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
863 controlFoldingReshapes(std::move(foldReshapes)) {}
864
  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
867 // Fold only if all constraints of fusing with reshape by expansion are met.
868 GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
869 if (!producer || producer.getNumOutputs() != 1 ||
870 !isFusableWithReshapeByDimExpansion(producer,
871 producer.getOutputOperand(0)) ||
872 !controlFoldingReshapes(producer->getResult(0),
873 reshapeOp->getOpOperand(0)))
874 return failure();
875 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
876 producer, reshapeOp, producer.getOutputOperand(0), rewriter);
877 if (!replacementValues)
878 return failure();
879 rewriter.replaceOp(reshapeOp, *replacementValues);
880 return success();
881 }
882
883 private:
884 ControlFusionFn controlFoldingReshapes;
885 };
886 } // namespace
887
888 //===---------------------------------------------------------------------===//
889 // Methods and patterns to fuse reshape with linalg.generic operations by
890 // contraction of dimensions.
891 //===---------------------------------------------------------------------===//
892
/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
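///
/// For example (illustrative): for an indexing map
/// (d0, d1, d2, d3) -> (d1, d3, d0) and the range reassociation [1, 2]
/// (i.e. result positions 1 and 2), the returned domain reassociation is
/// [3, 0].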
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
900 assert(indexingMap.isProjectedPermutation() &&
901 "expected projected permutation");
902
903 ReassociationIndices domainReassociation = llvm::to_vector<4>(
904 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
905 return indexingMap.getResults()[pos]
906 .cast<AffineDimExpr>()
907 .getPosition();
908 }));
909 // The projected permutation semantics ensures that there is no repetition of
910 // the domain indices.
911 return domainReassociation;
912 }
913
/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Returns true as well if the sequence does not appear in the map at all.
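///
/// For example (illustrative): with indexing map
/// (d0, d1, d2, d3) -> (d0, d2, d3, d1), the sequence [2, 3] is preserved
/// (d2 is immediately followed by d3), while the sequence [0, 1] is not
/// (d0 is not immediately followed by d1).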
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
919 assert(!dimSequence.empty() &&
920 "expected non-empty list for dimension sequence");
921 assert(indexingMap.isProjectedPermutation() &&
922 "expected indexing map to be projected permutation");
923
924 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
925 sequenceElements.insert(dimSequence.begin(), dimSequence.end());
926
927 unsigned dimSequenceStart = dimSequence[0];
928 for (const auto &expr : enumerate(indexingMap.getResults())) {
929 unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
931 if (dimInMapStart == dimSequenceStart) {
932 if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
933 return false;
934 // 1a. Check if sequence is preserved.
935 for (const auto &dimInSequence : enumerate(dimSequence)) {
936 unsigned dimInMap =
937 indexingMap.getResult(expr.index() + dimInSequence.index())
938 .cast<AffineDimExpr>()
939 .getPosition();
940 if (dimInMap != dimInSequence.value())
941 return false;
942 }
943 // Found the sequence. Projected permutation
944 // enforces that all AffineDimExprs in the result are unique, so no
945 // further checks are needed.
946 return true;
947 }
948 // 2. If position in the expr (which is of type AffineDimExpr) is part
949 // of sequence, return false here. This implies the entire sequence does not
950 // exist in the indexing map.
951 if (sequenceElements.count(dimInMapStart))
952 return false;
953 }
954 // 3. No element of sequence found. Return true.
955 return true;
956 }
957
// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
962 //
963 // Example:
964 //
965 // ```mlir
966 // #map = affine_map<(d0, d1) -> (d0, d1)>
967 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
968 // %2 = linalg.init_tensor [..] : tensor<?x4xf32>
969 // %3 = linalg.generic {
970 // indexing_maps = [#map, #map],
971 // iterator_types = ["parallel" ,"parallel"]}
972 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
973 // ```
974 //
975 // can be fused by collapsing the dimensions of the iteration space.
976 //
977 // ```mlir
978 // #map = affine_map<(d0) -> (d0)>
979 // %2 = linalg.init_tensor [..] : tensor<?xf32>
980 // %3 = linalg.generic {
981 // indexing_maps = [#map, #map],
982 // iterator_types = ["parallel"]}
983 // ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
984 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
985 // ```
986 //
987 // In the following example,
988 //
989 // ```mlir
990 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
991 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
992 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
993 // %2 = linalg.init_tensor [..] : tensor<4x?xf32>
994 // %2 = linalg.generic {
995 // indexing_maps = [#map0, #map1],
996 // iterator_types = ["parallel" ,"parallel"]}
997 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
998 // ```
999 //
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
1007 // Some basic checks for this fusion to be valid.
1008 if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1009 return {};
1010
1011 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1012 return map.isProjectedPermutation();
1013 })) {
1014 return {};
1015 }
1016
1017 // Compute all the loops with the reduction iterator types.
1018 SmallVector<int64_t> reductionDims;
1019 for (const auto &iteratorType : llvm::enumerate(genericOp.iterator_types())) {
1020 if (isReductionIterator(iteratorType.value())) {
1021 reductionDims.push_back(iteratorType.index());
1022 }
1023 }
1024
1025 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1026 AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
1027 auto iteratorTypes = genericOp.iterator_types().getValue();
1028 SmallVector<ReassociationIndices> iterationSpaceReassociation;
1029 for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1030 assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1031
1032 // Ignore dims that are not folded.
1033 if (foldedRangeDims.size() == 1)
1034 continue;
1035
1036 ReassociationIndices foldedIterationSpaceDims =
1037 getDomainReassociation(indexingMap, foldedRangeDims);
1038
1039 // Check that the folded iteration dims do not contain already processed
1040 // dims.
1041 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1042 return processedIterationDims.count(dim);
1043 }))
1044 continue;
1045
1046 // Check that all folded iterator types are all parallel or all reductions.
1047 Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
1048 if (!isParallelIterator(startIteratorType) &&
1049 !isReductionIterator(startIteratorType))
1050 continue;
1051 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1052 return iteratorTypes[dim] != startIteratorType;
1053 }))
1054 continue;
1055
1056 // If the folded dimensions correspond to a "reduction" iterator type,
1057 // the folded dimensions need to be "in-order". Strictly speaking this is
1058 // not necessary, for reductions that are associative and commutative, but
1059 // using a more strict definition of reduction for now.
1060 if (isReductionIterator(startIteratorType)) {
1061 bool isContiguous = false;
1062 for (const auto &startDim : llvm::enumerate(reductionDims)) {
1063 // Move window in `reductionDims` to start of the folded iteration dims.
1064 if (startDim.value() != foldedIterationSpaceDims[0])
1065 continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
1068 if (startDim.index() + foldedIterationSpaceDims.size() >
1069 reductionDims.size())
1070 break;
1071 // Check that the contiguity is maintained.
1072 isContiguous = true;
1073 for (const auto &foldedDim :
1074 llvm::enumerate(foldedIterationSpaceDims)) {
1075 if (reductionDims[foldedDim.index() + startDim.index()] !=
1076 foldedDim.value()) {
1077 isContiguous = false;
1078 break;
1079 }
1080 }
1081 break;
1082 }
1083 if (!isContiguous)
1084 continue;
1085 }
1086
1087 // Check that the sequence is preserved in all indexing maps.
1088 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1089 [&](AffineMap indexingMap) {
1090 return !isDimSequencePreserved(indexingMap,
1091 foldedIterationSpaceDims);
1092 }))
1093 continue;
1094
1095 processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1096 foldedIterationSpaceDims.end());
1097 iterationSpaceReassociation.emplace_back(
1098 std::move(foldedIterationSpaceDims));
1099 }
1100
1101 return iterationSpaceReassociation;
1102 }
1103
1104 /// Helper class to carry state while collapsing the `linalg.generic` op.
1105 namespace {
1106 class CollapsingInfo {
1107 public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
1110 llvm::SmallDenseSet<int64_t, 4> processedDims;
1111 // Find all the dims that are folded.
1112 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1113 if (foldedIterationDim.empty())
1114 continue;
1115 // If the folded dims contain dims already folded, that's illegal
1116 // specification. Repetition within a list is also illegal.
1117 for (auto dim : foldedIterationDim) {
1118 if (dim >= origNumLoops)
1119 return failure();
1120 if (processedDims.count(dim))
1121 return failure();
1122 processedDims.insert(dim);
1123 }
1124 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1125 foldedIterationDim.end());
1126 }
1127 if (processedDims.size() > origNumLoops)
1128 return failure();
1129
1130 // Add all the preserved dims of the original op as single
1131 // elements to `collapsedOpToOrigOpIterationDim`.
1132 for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1133 if (processedDims.count(dim))
1134 continue;
1135 collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1136 }
1137
1138 llvm::sort(collapsedOpToOrigOpIterationDim,
1139 [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1140 return lhs[0] < rhs[0];
1141 });
1142 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1143 for (const auto &foldedDims :
1144 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1145 for (const auto &dim : enumerate(foldedDims.value()))
1146 origOpToCollapsedOpIterationDim[dim.value()] =
1147 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1148 }
1149 return success();
1150 }
1151
1152 /// Return mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1154 return collapsedOpToOrigOpIterationDim;
1155 }
1156
1157 /// Return mapping from original loop domain to collapsed loop domain. The
1158 /// mapping is a pair. First value is the dimension in the collapsed loop that
1159 /// the original loop is mapped to. Second is the relative position in folded
1160 /// list of this domain. For example if the original loop domain is 3D, and
1161 /// the collapsed loop domain is folding all of it, i.e.
1162 ///
1163 /// ```
1164 /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
1165 /// ```
1166 ///
1167 /// then
1168 ///
1169 /// ```
1170 /// origOpToCollapsedOpMapping[0] = {0, 0};
1171 /// origOpToCollapsedOpMapping[1] = {0, 1};
1172 /// origOpToCollapsedOpMapping[2] = {0, 2};
1173 /// origOpToCollapsedOpMapping[3] = {1, 0};
1174 /// origOpToCollapsedOpMapping[4] = {1, 1};
1175 /// ```
1176 ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1178 return origOpToCollapsedOpIterationDim;
1179 }
1180
1181 /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
1183 return collapsedOpToOrigOpIterationDim.size();
1184 }
1185
1186 private:
1187 /// Map from the iteration domain index in collapsed op to the iteration
1188 /// domain indices in the original op.
1189 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1190
1191 /// Map from iteration domain index in the original op to the iteration domain
1192 /// index in the collapsed op.
1193 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1194 };
1195 } // namespace
1196
1197 /// Get the iterator types for the collapsed operation given the original
1198 /// iterator types and collapsed dimensions.
static SmallVector<StringRef>
getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
1202 SmallVector<StringRef> collapsedIteratorTypes;
1203 for (ReassociationIndicesRef foldedIterDims :
1204 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1205 assert(!foldedIterDims.empty() &&
1206 "reassociation indices expected to have non-empty sets");
1207 // Just pick the iterator type of the first folded dim. Pre-condition checks
1208 // expected to have checked that iterator types of all folded dimensions are
1209 // the same.
1210 collapsedIteratorTypes.push_back(
1211 iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1212 }
1213 return collapsedIteratorTypes;
1214 }
1215
/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
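///
/// For example (illustrative): if loops (d0, d1) of the original op are
/// collapsed into a single loop and d2 is kept, an original map
/// (d0, d1, d2) -> (d2, d0, d1) becomes (d0, d1) -> (d1, d0) in the collapsed
/// op.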
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
1221 MLIRContext *context = indexingMap.getContext();
1222 assert(indexingMap.isProjectedPermutation() &&
1223 "expected indexing map to be projected permutation");
1224 SmallVector<AffineExpr> resultExprs;
1225 auto origOpToCollapsedOpMapping =
1226 collapsingInfo.getOrigOpToCollapsedOpMapping();
1227 for (auto expr : indexingMap.getResults()) {
1228 unsigned dim = expr.cast<AffineDimExpr>().getPosition();
1229 // If the dim is not the first of the collapsed dim, do nothing.
1230 if (origOpToCollapsedOpMapping[dim].second != 0)
1231 continue;
1232 // The next n-dims are guaranteed to be collapsed. So just use the
1233 // iteration dimension of the collapsed op.
1234 resultExprs.push_back(
1235 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1236 }
1237 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1238 resultExprs, context);
1239 }
1240
/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
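///
/// For example (continuing the illustration above): with (d0, d1) collapsed
/// and d2 kept, an operand with indexing map (d0, d1, d2) -> (d0, d1, d2) is
/// collapsed using the reassociation [[0, 1], [2]].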
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
1246 unsigned counter = 0;
1247 SmallVector<ReassociationIndices> operandReassociation;
1248 auto origOpToCollapsedOpMapping =
1249 collapsingInfo.getOrigOpToCollapsedOpMapping();
1250 auto collapsedOpToOrigOpMapping =
1251 collapsingInfo.getCollapsedOpToOrigOpMapping();
1252 while (counter < indexingMap.getNumResults()) {
1253 unsigned dim =
1254 indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
1255 if (origOpToCollapsedOpMapping[dim].second == 0) {
1256 // This is the start of a collapsed dimensions of the iteration that
1257 // is gauranteed to be preserved in the indexing map. The number of folded
1258 // dims is obtained from the collapsed op to original op mapping.
1259 unsigned numFoldedDims =
1260 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1261 .size();
1262 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1263 operandReassociation.emplace_back(range.begin(), range.end());
1264 counter += numFoldedDims;
1265 }
1266 }
1267 return operandReassociation;
1268 }

/// Get the new value to use for a given `OpOperand` in the collapsed
/// operation.
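/// If the reassociation has as many groups as the operand has dimensions
/// (i.e. nothing is folded for this operand), the operand is returned
/// unchanged; otherwise a `tensor.collapse_shape` is created. For example
/// (illustrative), a `tensor<2x3x4xf32>` operand with reassociation
/// `[[0, 1], [2]]` is collapsed to `tensor<6x4xf32>`.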
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, there is nothing to do for
  // this operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
      loc, operand, operandReassociation);
  return reshapeOp.getResult();
}

/// Rewrite the `linalg.index` operations of the original generic op, which are
/// now inside the collapsed op's body, in terms of the `linalg.index` values
/// of the collapsed operation.
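/// Each original index of a folded dimension group is recovered from the
/// single index of the collapsed dimension by de-linearizing it with the loop
/// ranges of the folded dimensions (see the worked formula in the body below).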
static void generateCollapsedIndexingRegion(
    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
    ValueRange loopRange, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each group of folded dimensions, resolve the original induction
  // variable values in terms of the folded dimension's induction variable:
  //   i_{folded} = (i0 * d1 + i1) * d2 + i2
  // can be inverted to
  //   i2 = i_{folded} % d2
  //   i1 = (i_{folded} / d2) % d1
  //   i0 = i_{folded} / (d1 * d2)
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto &foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      indexReplacementVals[dim] =
          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
      newIndexVal =
          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.dim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}
/// Implementation of fusion with reshape operation by collapsing dimensions.
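/// The collapsed op is built by:
///   1. collapsing the iterator types of each folded dimension group,
///   2. rewriting the indexing maps in terms of the collapsed loops,
///   3. inserting `tensor.collapse_shape` ops on operands where needed,
///   4. regenerating `linalg.index` values when the op has index semantics,
///   5. inserting `tensor.expand_shape` ops on results to recover the
///      original result types.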
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
  // Bail on trivial no-op cases.
  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                       foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        genericOp, "illegal to collapse specified dimensions");
  }

  // Get the iterator types for the collapsed operation.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc = genericOp->getLoc();

  // Get the input operands.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    resultTypes.push_back(newOutput.getType());
  }

  // Create the collapsed generic op.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert an expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}

namespace {

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing the corresponding dimensions of the generic op's iteration
/// space.
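///
/// A minimal sketch of the rewrite (illustrative IR, exact types assumed):
///
///   %e = tensor.expand_shape %arg0 [[0, 1]]
///       : tensor<?xf32> into tensor<?x4xf32>
///   %r = linalg.generic ... ins(%e : tensor<?x4xf32>) ...
///
/// The generic op is rewritten with the expanded iteration dimensions
/// collapsed back into one, its operands wrapped in `tensor.collapse_shape`
/// where needed (which can then fold with the producing `expand_shape`), and
/// an `expand_shape` inserted on the result to restore the original type.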
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
        continue;
      }

      Optional<SmallVector<Value>> replacements =
          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
                                         opOperand, rewriter);
      if (!replacements) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, *replacements);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does
/// not handle cases where the constant is not single-valued.
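///
/// A minimal sketch (illustrative): an input fed by
/// `arith.constant dense<42.0> : tensor<4x8xf32>` is dropped from the generic
/// op's operands, and the payload block argument that read it is replaced by
/// a scalar `arith.constant 42.0 : f32` created just before the fused op.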
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation, with the
      // entries for the constant operand dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument to the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
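///
/// Replacing an unused `outs` value with a fresh `linalg.init_tensor` drops
/// the use-def dependence on whatever produced that value, which in turn can
/// unblock the other fusion patterns in this file.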
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparse compiler.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

/// Fold `linalg.fill` into `linalg.generic`.
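///
/// Only the payload uses of the filled tensor are replaced with the fill
/// value; the tensor itself remains an operand of the generic op. The pattern
/// succeeds only if at least one fill value was forwarded this way.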
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.region().front();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      payload.getArgument(opOperand->getOperandNumber())
          .replaceAllUsesWith(fillOp.value());
    }
    return success(fillFound);
  }
};
} // namespace

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
}

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
// favor of test passes that check the functionality of each of the patterns
// added here individually.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    // Default control function: fuse only when the producer has a single use.
    ControlFusionFn defaultControlFn = [](const OpResult &producer,
                                          const OpOperand &consumer) {
      return producer.hasOneUse();
    };

    // Add elementwise op fusion patterns and reshape-by-expansion patterns.
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);

    // General canonicalization patterns.
    AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
        patterns);

    // Add constant folding patterns.
    populateConstantFoldLinalgOperations(patterns, defaultControlFn);

    // Use top-down traversal for compile-time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}