//===-------- SplitReduction.cpp - Split reduction dimension -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel dimension and a smaller reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

/// Return the neutral element attribute of the given combiner op, i.e. the
/// value `e` such that `op(x, e) == x` for all `x`, or a null attribute if
/// the op is not a recognized reduction combiner.
static Attribute getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
    // The neutral element of min is the largest *positive* value.
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
    return Attribute();
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  // Signed min/max need the extreme value at the type's actual bit width;
  // passing the int64_t limits directly would truncate for narrower types.
  if (auto intType = resultType.dyn_cast<IntegerType>()) {
    if (isa<arith::MaxSIOp>(op))
      return b.getIntegerAttr(
          resultType, llvm::APInt::getSignedMinValue(intType.getWidth()));
    if (isa<arith::MinSIOp>(op))
      return b.getIntegerAttr(
          resultType, llvm::APInt::getSignedMaxValue(intType.getWidth()));
  }
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return Attribute();
}

FailureOr<LinalgOp> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &filter) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");

  FailureOr<SplitReductionResult> res =
      splitReduction(b, op, controlSplitReductionFn);
  if (failed(res))
    return failure();

  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);

  return res->splitLinalgOp;
}
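
// For intuition, a minimal sketch of the IR the rewrite below produces for a
// 1-D f32 add reduction of size 32, assuming the control function returns a
// ratio of 4 and an insertion position of 0 (SSA names and the elided
// attributes are illustrative, not verbatim output):
//
//   %expanded = tensor.expand_shape %in [[0, 1]]
//       : tensor<32xf32> into tensor<4x8xf32>
//   %init = linalg.init_tensor [4] : tensor<4xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>)
//       -> tensor<4xf32>
//   // Partial sums: parallel over the new outer dimension of size 4.
//   %partial = linalg.generic {iterator_types = ["parallel", "reduction"], ...}
//       ins(%expanded : tensor<4x8xf32>) outs(%fill : tensor<4xf32>) {...}
//   // Final combine of the 4 partial results into the original output.
//   %res = linalg.generic {iterator_types = ["reduction"], ...}
//       ins(%partial : tensor<4xf32>) outs(%out : tensor<f32>) {...}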

FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertSplitDimension = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1 && "expected a single reduction dimension");
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be static and divisible by the split "
            "ratio, and the insertion position must be in range");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Attribute identity = getNeutralElement(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op,
                                "unknown neutral element for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // Split the reduced dimension into `ratio` outer slices.
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }
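
  // As a concrete illustration of the reindexing (an assumed example, not
  // tied to any particular caller): for O(i) += I(i, k) with I of type
  // tensor<3x128xf32>, a ratio of 4 and insertSplitDimension == 1, the input
  // map (d0, d1) -> (d0, d1) becomes (d0, d1, d2) -> (d0, d1, d2) and I is
  // expanded to tensor<3x4x32xf32>. The output computed below becomes a
  // tensor<3x4xf32> intermediate with map (d0, d1, d2) -> (d0, d1) and
  // iterator types ["parallel", "parallel", "reduction"].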

  // Calculate the new output map and shape; the new dimension is inserted at
  // the position returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertSplitDimension) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
      continue;
    }
    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
  }
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction op that reduces only the newly added
  // dimension of the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitDimension == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
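
  // Continuing the assumed example above (intermRank == 2 and
  // insertSplitDimension == 1): inputMap is (d0, d1) -> (d0, d1), outputMap
  // is (d0, d1) -> (d0), and the iterator types are
  // ["parallel", "reduction"], so only the inserted dimension is reduced.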

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        // Clone the original combiner and rewire it to combine the partial
        // result with the original output.
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...).
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

/// Insert a new parallel result at position `reductionDimPos` in the indexing
/// map of `opOperand`. Note that `size` is currently unused.
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}
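
// A worked illustration of the two helpers above, under assumed values
// (reductionDimPos == 1, reductionRatio == 16):
//   - scaleReductionDim on (d0, d1) -> (d0, d1) yields
//     (d0, d1, d2) -> (d0, d1 * 16 + d2).
//   - insertParallelDim on (d0, d1) -> (d0) yields
//     (d0, d1, d2) -> (d0, d1).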

/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part: enforce preconditions.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t splitFactor = control.first;
  unsigned insertSplitDimension = control.second;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
        return getNeutralElement(reductionOp);
      }));
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral element");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumOutputs() != (int)neutralElements.size())
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs have the same shape as the original ones,
  // with an extra dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //     a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
  //     b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumOutputs());
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumOutputs());
  for (auto it : llvm::zip(op.outputs(), neutralElements)) {
    Value rankedTensor = std::get<0>(it);
    auto t = rankedTensor.getType().cast<RankedTensorType>();
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value initTensor = b.create<linalg::InitTensorOp>(
        loc, dims, newT.getShape(), t.getElementType());
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
  for (OpOperand *o : op.getInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along `reductionDimPos` and should
  // be reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand *o : op.getOutputOperands())
    newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  auto newInputs = llvm::to_vector<4>(op.inputs());
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<linalg::InitTensorOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  SmallVector<StringRef> iteratorTypes =
      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       getParallelIteratorTypeName());
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());
  genericOp.region().front().insertArgument(reductionDimPos,
                                            b.getIntegerType(1), loc);
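
  // At this point in the example above, the new generic has iterator types
  // ["parallel", "parallel", "parallel", "reduction"] (the inserted parallel
  // dimension plus the scaled-down reduction) and takes I together with an
  // 8x16 shape-only i1 tensor as inputs.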

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType =
        originalOutput.getType().cast<RankedTensorType>();
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<StringRef> reductionIteratorTypes(
        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
    reductionIteratorTypes[insertSplitDimension] =
        getReductionIteratorTypeName();

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
      loc,
      originalOutputType,
      reindexedOutput,
      originalOutput,
      indexingMaps,
      reductionIteratorTypes,
      [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
        Operation *clonedReductionOp = b.clone(*combinerOp);
        clonedReductionOp->setOperand(0, bbArgs[0]);
        clonedReductionOp->setOperand(1, bbArgs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, f);
}
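
// A minimal usage sketch (hypothetical caller code, not part of this file;
// `funcOp` and the fixed {4, 0} control values are assumptions for
// illustration):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateSplitReductionPattern(
//       patterns,
//       [](LinalgOp) { return std::pair<int64_t, unsigned>(4, 0); },
//       LinalgTransformationFilter());
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();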