//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. the peeled operation),
/// and a second GenericOp with the remaining statements (i.e. the residual
/// operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   all the results of the first GenericOp. It has the same number of yields
///   as the original op.
/// - If a result of the peeled operation was yielded by the original
///   GenericOp, the uses of the corresponding results will be replaced with
///   the result of the first GenericOp created.
///
/// Example
///
/// ```mlir
///  %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///      %0 = <s0> %b0, %b1 : ...
///      %1 = <s1> %0, %b2 : ...
///      linalg.yield %0, %1 : ...
///  } -> (..., ...)
///  return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
/// %init = linalg.init_tensor ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1, %init : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0, %..., %0 : ...
///  } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %1 = <s1> %b3, %b2 : ...
///      linalg.yield %..., %1 : ...
///  } -> (..., ...)
///  return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
/// %init = linalg.init_tensor ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///      outs(%init : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0 : ...
///  } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///      outs(%init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %1 = <s1> %b1, %b0 : ...
///      linalg.yield %1 : ...
///  } -> ...
///  return %op0, %op1
/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operations.
  /// The created op also has an empty region; the body is populated later by
  /// splicing in the remaining statements of the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the loop ranges of a generic op, i.e. the shape of
/// its iteration space.
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes));
}

/// Helper method to permute the list of `values` based on the `map`.
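/// For example (an illustrative case): permuting `[s0, s1, s2]` by the map
/// `(d0, d1, d2) -> (d1, d2, d0)` places the i-th value at the dimension
/// position of the map's i-th result, producing `[s2, s0, s1]`.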
static SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                               AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}

/// Get zero value for an element type.
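/// For example, this would create `arith.constant 0 : i32` for an `i32`
/// element type and a positive-zero float constant for float types.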
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (elementType.isa<IntegerType>())
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = elementType.cast<FloatType>();
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  /// Compute the loop ranges for the operation. This is the shape of the
  /// result of the generic op for the peeled operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  /// The indexing map to use for each new result is obtained by:
  /// - Checking if the result is yielded. If so, use the same indexing map as
  ///   the corresponding output.
  /// - Using the identity indexing map if the result is not yielded.
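  /// For example (an illustrative case): a peeled result yielded only through
  /// an output with indexing map `(d0, d1) -> (d1, d0)` reuses that map, while
  /// a result that is not yielded (or that also has an identity-mapped yield
  /// use) gets the identity map `(d0, d1) -> (d0, d1)`.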
  Operation *yieldOp = body->getTerminator();
  auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap {
    OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr;
    for (OpOperand &use : scalarOpResult.getUses()) {
      if (use.getOwner() != yieldOp)
        continue;
      if (!firstUseInYield)
        firstUseInYield = &use;
      OpResult genericOpResult =
          genericOp.getResult(use.getOperandNumber()).cast<OpResult>();
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(genericOpResult);
      if (indexingMap.isIdentity())
        identityUseInYield = &use;
    }
    if (identityUseInYield || !firstUseInYield)
      return rewriter.getMultiDimIdentityMap(domain.size());
    OpResult genericOpResult =
        genericOp.getResult(firstUseInYield->getOperandNumber())
            .cast<OpResult>();
    return genericOp.getTiedIndexingMapForResult(genericOpResult);
  };

  for (auto scalarResult : peeledScalarOperation->getResults()) {
    AffineMap resultIndexingMap = getResultIndexingMap(scalarResult);
    SmallVector<OpFoldResult> initSize =
        permuteValues(domain, resultIndexingMap);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, initSize, scalarResult.getType());
    newInitValues.push_back(initTensor);
    newResultTypes.push_back(initTensor.getType());
    peeledGenericOpIndexingMaps.push_back(resultIndexingMap);
  }

  /// Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr,
      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  /// Append all results from the peeledGenericOp as `ins` operands for the
  /// residual generic op.
  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(),
                      [](OpOperand *operand) { return operand->get(); }));
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  /// Add indexing maps for the newly added operands. Use the same maps as
  /// those used for the new results of the peeledGenericOp.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
        return genericOp.getTiedIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
    indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result));
  }
  for (OpOperand *outOperand : genericOp.getOutputOperands())
    indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));
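  // Note: `linalg.generic` expects exactly one indexing map per operand, in
  // operand order (all `ins` first, then all `outs`); hence the maps for the
  // original outputs are appended last.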

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.outputs(), indexingMapAttr,
      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  /// For now only match on operations where the iterator types are all
  /// parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too but requires allocation for intermediates. Punt on this for
  // now.
  if (!genericOp.hasTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  // TODO: For now only decompose operations where the `outs` operand values
  // are not accessed within the payload. This might be relaxed in the future,
  // but needs a bit more reasoning to ensure that it is safe.
  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
        return genericOp.payloadUsesValueFromOperand(outOperand);
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with use of out "
                   "operand value in payload");
  }

  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
        return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  /// If the op has only a single statement (apart from the yield), do nothing.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has less than 3 statements");
  }

  /// Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  /// Move the first statement of the original operation into the body of the
  /// generic op for the peeled operation.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());
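  // At this point the peeled body holds only the first scalar statement, and
  // the residual body holds all remaining statements, including the original
  // yield; the original body has been emptied.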

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the results of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  /// In the split operations, replace uses of block arguments that refer to
  /// the original operation with the block arguments of the newly created
  /// operations.
  unsigned origNumInputs = genericOp.getNumInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  /// Before fixing up the residual operation, track which values are yielded.
  /// If any of those come from the peeled scalar operation, the uses of the
  /// corresponding result have to be remapped to the result of the generic op
  /// for the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  /// Update all uses of the peeled scalar operation results in the residual op
  /// to the newly added block arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
                                  residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
  }

  // Replace the original operation.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}

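/// Typical driver hookup (a minimal sketch; the greedy rewrite driver call and
/// the enclosing pass boilerplate are assumptions, not part of this file):
///
///   RewritePatternSet patterns(op->getContext());
///   linalg::populateDecomposeLinalgOpsPattern(patterns);
///   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
///     signalPassFailure();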
void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
}