//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. the peeled operation),
/// and a second GenericOp with the remaining statements (i.e. the residual
/// operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   all the results of the first GenericOp. It has the same number of yields
///   as the original op.
/// - If a result of the peeled operation was yielded by the original
///   GenericOp, the uses of the corresponding results will be replaced with
///   the result of the first GenericOp created.
///
/// Example
///
/// ```mlir
///  %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///      %0 = <s0> %b0, %b1 : ...
///      %1 = <s1> %0, %b2 : ...
///      linalg.yield %0, %1 : ...
///  } -> (..., ...)
///  return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
///  %init = linalg.init_tensor ...
///  %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1, %init : ...)
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0, %..., %0 : ...
///  } -> (..., ..., ...)
///  %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %1 = <s1> %b3, %b2 : ...
///      linalg.yield %..., %1 : ...
///  } -> (..., ...)
///  return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
///  %init = linalg.init_tensor ...
///  %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///      outs(%init : ...)
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0 : ...
///  } -> ...
///  %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///      outs(%init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %1 = <s1> %b1, %b0 : ...
///      linalg.yield %1 : ...
///  } -> ...
///  return %op0, %op1
/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operations.
  /// The created op has the same region as the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the loop range (iteration domain) of a generic op.
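/// For example, for a generic op with a single `tensor<?x?xf32>` operand and
/// an identity indexing map, this returns the two dimension sizes of that
/// operand.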
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes));
}

/// Helper method to permute the list of `values` based on the `map`.
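/// For example, permuting values [%v0, %v1, %v2] by the permutation map
/// (d0, d1, d2) -> (d2, d0, d1) produces [%v1, %v2, %v0]: the value at
/// position i moves to the position named by the i-th result of the map.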
static SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                               AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}

/// Get the zero value for an element type.
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (elementType.isa<IntegerType>())
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = elementType.cast<FloatType>();
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  /// Compute the loop ranges for the operation. This is the shape of the
  /// result of the generic op for the peeled operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  /// The indexing maps to use for the new results are obtained by:
  /// - checking if the scalar result is yielded; if so, using the same
  ///   indexing map as the corresponding output;
  /// - using the identity indexing map if the result is not yielded.
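  /// For example, a peeled scalar result that is yielded into a generic op
  /// result written with map `(d0, d1) -> (d1, d0)` reuses that map for the
  /// new result (unless some yield of the same value uses the identity map),
  /// while a result that is not yielded gets the identity map over the
  /// iteration space.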
  Operation *yieldOp = body->getTerminator();
  auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap {
    OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr;
    for (OpOperand &use : scalarOpResult.getUses()) {
      if (use.getOwner() != yieldOp)
        continue;
      if (!firstUseInYield)
        firstUseInYield = &use;
      OpResult genericOpResult =
          genericOp.getResult(use.getOperandNumber()).cast<OpResult>();
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(genericOpResult);
      if (indexingMap.isIdentity())
        identityUseInYield = &use;
    }
    if (identityUseInYield || !firstUseInYield)
      return rewriter.getMultiDimIdentityMap(domain.size());
    OpResult genericOpResult =
        genericOp.getResult(firstUseInYield->getOperandNumber())
            .cast<OpResult>();
    return genericOp.getTiedIndexingMapForResult(genericOpResult);
  };

  for (auto scalarResult : peeledScalarOperation->getResults()) {
    AffineMap resultIndexingMap = getResultIndexingMap(scalarResult);
    SmallVector<OpFoldResult> initSize =
        permuteValues(domain, resultIndexingMap);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, initSize, scalarResult.getType());
    newInitValues.push_back(initTensor);
    newResultTypes.push_back(initTensor.getType());
    peeledGenericOpIndexingMaps.push_back(resultIndexingMap);
  }

  /// Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr,
      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  /// Append the results of the peeledGenericOp that correspond to the peeled
  /// scalar operation as `ins` operands of the residual generic op.
  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(),
                      [](OpOperand *operand) { return operand->get(); }));
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  /// Add indexing maps for the newly added operands. Use the same maps
  /// as those used for the new results of the peeledGenericOp.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
        return genericOp.getTiedIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
    indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result));
  }
  for (OpOperand *outOperand : genericOp.getOutputOperands())
    indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.outputs(), indexingMapAttr,
      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  /// For now, only match operations where the iterator types are all parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }
  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too but requires allocation for intermediates. Punt on this for
  // now.
  if (!genericOp.hasTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  // TODO: For now only decompose operations where the `outs` operand values
  // are not accessed within the payload. This might be relaxed in the future,
  // but needs a bit more reasoning to ensure that it is safe.
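  // For example, an op like the following is rejected because %b2, the block
  // argument tied to the `outs` operand, is used in the payload:
  //   linalg.generic ... ins(%arg0, %arg1 : ...) outs(%init : ...) {
  //   ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
  //     %0 = arith.addf %b0, %b2 : f32
  //     ...
  //   }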
  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
        return genericOp.payloadUsesValueFromOperand(outOperand);
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with use of out "
                   "operand value in payload");
  }

  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
        return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  /// If the op has only a single statement (apart from the yield), do nothing.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has fewer than 3 statements");
  }

  /// Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  /// Move the first statement of the original operation into the body of the
  /// generic op for the peeled operation, and the remaining statements into
  /// the body of the residual generic op.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty regions");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the results of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
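        // The peeled op cannot produce this value, so yield a zero
        // placeholder. The corresponding peeled op result is unused (the
        // residual op computes the real value) and is expected to fold away
        // during canonicalization.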
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  /// In the split operations, replace uses of the original operation's block
  /// arguments with the block arguments of the newly created operations.
  unsigned origNumInputs = genericOp.getNumInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  /// Before fixing up the residual operation, track which values are yielded.
  /// If any of them come from the peeled scalar operation, the uses of the
  /// corresponding results have to be remapped to the results of the generic
  /// op for the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  /// Update all uses of the peeled scalar operation's results in the residual
  /// op to use the newly added block arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
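    // The peeled scalar op's results were appended as extra `ins` operands of
    // the residual op, so their block arguments start right after the
    // original input arguments.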
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
                                  residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation was erased when it wasn't expected to be");
  }

  // Replace the original operation.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}

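// A minimal usage sketch (an assumption, not part of this file: a pass that
// has an `op` and applies the patterns via the greedy rewrite driver):
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateDecomposeLinalgOpsPattern(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();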
void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
}