//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel dimension and a smaller reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

/// Return the neutral element attribute of the given combiner op, i.e. the
/// value `e` such that `combine(e, x) == x`, or a null attribute if the op is
/// not recognized.
static Attribute getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
    // The neutral element of min is the largest (most positive) value, not the
    // most negative one.
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
    return Attribute();
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return Attribute();
}

FailureOr<LinalgOp> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &filter, bool useAlloc) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(
        op, "precondition not met: expected tensor semantics, a single "
            "reduction loop, a single output and only projected permutations");

  FailureOr<SplitReductionResult> res =
      splitReduction(b, op, controlSplitReductionFn, useAlloc);
  if (failed(res))
    return failure();

  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);

  return res->splitLinalgOp;
}
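
// Sketch of the rewrite performed by the core implementation below, on a
// minimal example with shapes assumed purely for illustration:
//   O(i) = sum_k I(i, k)                    // i = 16, k = 32
// With ratio = 4 and insertSplitDimension = 0, the rewrite produces:
//   I_e = expand_shape(I)     // tensor<16x32xf32> into tensor<16x4x8xf32>
//   a. O_i(kk, i) = sum_k2 I_e(i, kk, k2)   // kk parallel, k2 still reduced
//   b. O(i) = sum_kk O_i(kk, i)             // combine the partial results
// where the 4x16 intermediate O_i is initialized with the neutral element of
// the combiner (0 for this additive reduction).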

FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertSplitDimension = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1 && "expected a single reduction dimension");
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be static and divisible by the split "
            "ratio, and the insertion position must be in range");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Attribute identity = getNeutralElement(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }
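
  // For instance (a sketch reusing the example before this function): with
  // reductionDim = 1, ratio = 4 and insertSplitDimension = 0, the loop above
  // turns an input map
  //   (d0, d1) -> (d0, d1)
  // into
  //   (d0, d1, d2) -> (d1, d0, d2)
  // applied to the input expanded from tensor<16x32xf32> into
  // tensor<16x4x8xf32>, where d0 is the newly added parallel dimension.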

  // Calculate the new output map and shape; the new dimension is inserted at
  // the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertSplitDimension) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
      continue;
    }
    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
  }
  Value initOrAllocTensor;
  if (useAlloc) {
    initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    initOrAllocTensor = b.create<linalg::InitTensorOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitDimension == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{
      initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
      cast<LinalgOp>(genericOp.getOperation()), reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...).
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
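/// For example (a sketch, with reductionDimPos = 1 and reductionRatio = 4),
/// an operand map
///   (d0, d1) -> (d0, d1)
/// is rewritten into
///   (d0, d1, d2) -> (d0, d1 * 4 + d2)
/// where the newly created dimension d2 has extent 4 (the split factor) and
/// d1 now iterates over a quarter of the original reduction extent.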
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}
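
/// Insert the newly created parallel loop dimension (at loop position
/// `reductionDimPos`) into the indexing map of `opOperand`, materializing it
/// as an extra result at result position `reductionDimPos`. For example (a
/// sketch, with reductionDimPos = 1), an output map that drops the reduced
/// dimension
///   (d0, d1) -> (d0)
/// becomes
///   (d0, d1, d2) -> (d0, d1)
/// so that the partial reductions along the new dimension survive in the
/// intermediate output.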
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t splitFactor = control.first;
  unsigned insertSplitDimension = control.second;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension must be static and divisible by the "
            "split factor, and the insertion position must be in range");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
        return getNeutralElement(reductionOp);
      }));
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs have the same shape as the original ones,
  // with an extra dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, kk * 16 + k')
  //   b. O(i, j) += O_i(kk, i, j)
  // where kk iterates over 128/16 == 8 values, k' over 16 values, and the
  // intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumOutputs());
  SmallVector<Operation *> initOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumOutputs());
  for (auto it : llvm::zip(op.outputs(), neutralElements)) {
    Value rankedTensor = std::get<0>(it);
    auto t = rankedTensor.getType().cast<RankedTensorType>();
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value initOrAllocTensor;
    if (useAlloc) {
      initOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      initOrAllocTensor = b.create<linalg::InitTensorOp>(
          loc, dims, newT.getShape(), t.getElementType());
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
  for (OpOperand *o : op.getInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand *o : op.getOutputOperands())
    newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  auto newInputs = llvm::to_vector<4>(op.inputs());
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<linalg::InitTensorOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
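  // Following the running example above (k = 128, splitFactor = 16, reduction
  // at loop position 2), this shape-only tensor is an 8x16 i1 tensor indexed
  // by the map
  //   (d0, d1, d2, d3) -> (d2, d3)
  // Its only purpose is to give the two halves of the split loop static
  // extents; its i1 elements are never read.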
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  SmallVector<StringRef> iteratorTypes =
      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       getParallelIteratorTypeName());
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());
  // The shape-only tensor is the last input, so its (unused) bbarg must be
  // inserted right after the original input arguments.
  genericOp.region().front().insertArgument(unsigned(op.getNumInputs()),
                                            b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and thus do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<StringRef> reductionIteratorTypes(
        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
    reductionIteratorTypes[insertSplitDimension] =
        getReductionIteratorTypeName();

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that satisfy
  /// `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, bool useAlloc = false,
                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc), filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter,
                          useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
  LinalgTransformationFilter filter;
};

} // namespace
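
// Usage sketch (illustrative only, not part of this file): register the
// pattern with a fixed split ratio of 4 that inserts the extra parallel
// dimension outermost, then run the greedy driver:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateSplitReductionPattern(
//       patterns,
//       [](LinalgOp op) { return std::pair<int64_t, unsigned>(4, 0); },
//       LinalgTransformationFilter());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));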

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, f, useAlloc);
}