//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;
/// Return the neutral element attribute associated with the given combiner op,
/// i.e. the value `e` such that `op(x, e) == x` for all `x`.
static Attribute getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    // The neutral element of max is the most negative value; for min it is
    // the most positive one.
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
    return Attribute();
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  // Use width-aware limits so that narrow integer types get their actual
  // min/max rather than a truncated 64-bit constant.
  if (isa<arith::MaxSIOp>(op))
    return IntegerAttr::get(
        resultType,
        APInt::getSignedMinValue(resultType.getIntOrFloatBitWidth()));
  if (isa<arith::MinSIOp>(op))
    return IntegerAttr::get(
        resultType,
        APInt::getSignedMaxValue(resultType.getIntOrFloatBitWidth()));
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return Attribute();
}

FailureOr<LinalgOp> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &filter, bool useAlloc) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");

  FailureOr<SplitReductionResult> res =
      splitReduction(b, op, controlSplitReductionFn, useAlloc);
  if (failed(res))
    return failure();

  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);

  return res->splitLinalgOp;
}

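/// For illustration, a sketch of the rewrite performed below, assuming an f32
/// add reduction over a 32-element tensor with ratio=4 and
/// insertSplitDimension=0 (IR abbreviated; the exact attribute spelling may
/// differ):
///
///   %e = tensor.expand_shape %in [[0, 1]]
///       : tensor<32xf32> into tensor<4x8xf32>
///   %init = linalg.init_tensor [4] : tensor<4xf32>
///   %fill = linalg.fill ins(%cst0 : f32) outs(%init : tensor<4xf32>)
///       -> tensor<4xf32>
///   // Partial reductions over the remainder of the split dimension.
///   %partial = linalg.generic {iterator_types = ["parallel", "reduction"]}
///       ins(%e : tensor<4x8xf32>) outs(%fill : tensor<4xf32>) {...}
///   // Combine the partial results into the original output.
///   %r = linalg.generic {iterator_types = ["reduction"]}
///       ins(%partial : tensor<4xf32>) outs(%out : tensor<f32>) {...}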
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertSplitDimension = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be static and divisible by the ratio");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Attribute identity = getNeutralElement(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op,
                                "unknown neutral element for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
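  // For example (a sketch, assuming reductionDim=1, insertSplitDimension=0
  // and ratio=16): an input indexed by (d0, d1) -> (d0, d1) of shape 3x128
  // becomes (d0, d1, d2) -> (d1, d0, d2) of shape 3x16x8, i.e. the reduction
  // dimension is expanded into a (ratio, size / ratio) pair.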
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged, the input doesn't need to be expanded.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted at
  // the index returned by `controlSplitReductionFn`.
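  // For example (sketch, continuing with insertSplitDimension=0): an output
  // indexed by (d0, d1) -> (d0) becomes (d0, d1, d2) -> (d0, d1) with the new
  // leading dimension of extent `ratio`.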
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertSplitDimension) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
      continue;
    }
    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
  }
  Value initOrAllocTensor;
  if (useAlloc) {
    initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    initOrAllocTensor = b.create<linalg::InitTensorOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
  Value identityTensor =
      b.create<linalg::FillOp>(loc, constantOp, initOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
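  // E.g., continuing the sketch above with insertSplitDimension=0:
  // ["parallel", "reduction"] becomes ["parallel", "parallel", "reduction"].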
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction op that only reduces the newly added
  // dimension of the previous op's result.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitDimension == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{
      initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
      cast<LinalgOp>(genericOp.getOperation()), reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...).
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
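/// For example (sketch): a matmul LHS map (d0, d1, d2) -> (d0, d2) with
/// reductionDimPos=2 and reductionRatio=4 becomes
/// (d0, d1, d2, d3) -> (d0, d2 * 4 + d3).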
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

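/// Insert a dimension into `opOperand`'s indexing map: dimensions at or above
/// `reductionDimPos + 1` are shifted up by one and a result referencing the
/// new parallel dimension is inserted at result position `reductionDimPos`.
/// For example (sketch): a matmul output map (d0, d1, d2) -> (d0, d1) with
/// reductionDimPos=2 becomes (d0, d1, d2, d3) -> (d0, d1, d2).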
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
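///
/// For illustration, a sketch of the expected rewrite on a matmul, assuming
/// splitFactor=4 and insertSplitDimension=2 (IR abbreviated; the exact maps
/// are produced by the helpers above):
///
///   %0 = linalg.matmul ins(%A, %B : tensor<16x256xf32>, tensor<256x32xf32>)
///     outs(%C : tensor<16x32xf32>) -> tensor<16x32xf32>
///
/// becomes, with %A indexed by (d0, d1, d2, d3) -> (d0, d2 * 4 + d3):
///
///   %i = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
///   %f = linalg.fill ins(%cst0 : f32) outs(%i : ...) -> tensor<16x32x64xf32>
///   // Shape-only tensor carrying the split sizes.
///   %s = linalg.init_tensor [64, 4] : tensor<64x4xi1>
///   %p = linalg.generic {iterator_types = ["parallel", "parallel",
///                                          "parallel", "reduction"]}
///     ins(%A, %B, %s : ...) outs(%f : ...) {...} -> tensor<16x32x64xf32>
///   %r = linalg.generic {iterator_types = ["parallel", "parallel",
///                                          "reduction"]}
///     ins(%p : tensor<16x32x64xf32>) outs(%C : ...) {...}
///       -> tensor<16x32xf32>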
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part: enforce preconditions.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t splitFactor = control.first;
  unsigned insertSplitDimension = control.second;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
        return getNeutralElement(reductionOp);
      }));
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral element");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra
  // dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(k, i, j) += I(i, j, k * 16 + kk)
  //   b. O(i, j) += O_i(k, i, j)
  // where k now ranges over 128/16=8 as a parallel dimension, kk ranges over
  // 16 and is reduced in a., and the leftover k is reduced in b. The
  // intermediate tensor O_i has shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumOutputs());
  SmallVector<Operation *> initOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumOutputs());
  for (auto it : llvm::zip(op.outputs(), neutralElements)) {
    Value rankedTensor = std::get<0>(it);
    auto t = rankedTensor.getType().cast<RankedTensorType>();
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value initOrAllocTensor;
    if (useAlloc) {
      initOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      initOrAllocTensor = b.create<linalg::InitTensorOp>(
          loc, dims, newT.getShape(), t.getElementType());
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(loc, constantOp, initOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
  for (OpOperand *o : op.getInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing map for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
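  // E.g. for the matmul sketch above: (d0, d1, d2, d3) -> (d2, d3).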
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand *o : op.getOutputOperands())
    newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  auto newInputs = llvm::to_vector<4>(op.inputs());
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<linalg::InitTensorOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  SmallVector<StringRef> iteratorTypes =
      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       getParallelIteratorTypeName());
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());
  genericOp.region().front().insertArgument(reductionDimPos,
                                            b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<StringRef> reductionIteratorTypes(
        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
    reductionIteratorTypes[insertSplitDimension] =
        getReductionIteratorTypeName();

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
      loc,
      originalOutputType,
      reindexedOutput,
      originalOutput,
      indexingMaps,
      reductionIteratorTypes,
      [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
        Operation *clonedReductionOp = b.clone(*combinerOp);
        clonedReductionOp->setOperand(0, bbArgs[0]);
        clonedReductionOp->setOperand(1, bbArgs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOps that satisfy
  /// `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, bool useAlloc = false,
                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc), filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter,
                          useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, f, useAlloc);
}
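
// For illustration, a minimal usage sketch (`funcOp` and the constant control
// function are hypothetical; any ControlSplitReductionFn works):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateSplitReductionPattern(
//       patterns,
//       [](LinalgOp op) { return std::make_pair(int64_t(4), unsigned(0)); },
//       LinalgTransformationFilter());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));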