//===-------- SplitReduction.cpp - Split reduction dimesion ---------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements linalg transformation to break a reduction dimension // between a parallel and a reduction dimension. // //===----------------------------------------------------------------------===// #include #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::linalg; /// Return the identity numeric value associated to the give op. static Optional getIdentity(Operation *op) { // Builder only used as helper for attribute creation. OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); if (auto floatType = resultType.dyn_cast()) { const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1)); if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat::getLargest(semantic, true)); if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat::getLargest(semantic, true)); return llvm::None; } if (isa(op)) return b.getIntegerAttr(resultType, 0); if (isa(op)) return b.getIntegerAttr(resultType, -1); if (isa(op)) return b.getIntegerAttr(resultType, std::numeric_limits::min()); if (isa(op)) return b.getIntegerAttr(resultType, std::numeric_limits::max()); if (isa(op)) return b.getIntegerAttr(resultType, 1); return llvm::None; } FailureOr mlir::linalg::splitReduction( PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &filter) { if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() || op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations()) return b.notifyMatchFailure(op, "precondition not met"); std::pair control = controlSplitReductionFn(op); int64_t ratio = control.first; unsigned insertDimIndex = control.second; if (ratio <= 1) return b.notifyMatchFailure(op, "split ratio needs to be greater than 1"); SmallVector dims; op.getReductionDims(dims); assert(dims.size() == 1); unsigned reductionDim = dims[0]; Optional> loopRanges = op.getStaticLoopRanges(); if (!loopRanges) return b.notifyMatchFailure(op, "Cannot analyze loops"); int64_t reductionDimSize = (*loopRanges)[reductionDim]; if (reductionDimSize == ShapedType::kDynamicSize || reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges->size()) return b.notifyMatchFailure( op, "Reduction dimension not divisible by split ratio"); SmallVector combinerOps; if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) || combinerOps.size() != 1) return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); Operation *reductionOp = combinerOps[0]; Optional identity = getIdentity(reductionOp); if (!identity) return b.notifyMatchFailure(op, "Unknown identity value for the redution"); Location loc = op->getLoc(); SmallVector newInputs; SmallVector newMaps; // Calculate the new shapes and indexing maps of the input operands. for (OpOperand *operand : op.getInputOperands()) { AffineMap map = op.getTiedIndexingMap(operand); SmallVector newShape; SmallVector exprs; SmallVector reassociation; unsigned index = 0; for (unsigned idx : llvm::seq(0, map.getNumResults())) { unsigned dim = map.getDimPosition(idx); if (reductionDim == dim) { newShape.push_back(ratio); newShape.push_back(op.getShape(operand)[idx] / ratio); reassociation.push_back({index++, index++}); exprs.push_back(b.getAffineDimExpr(insertDimIndex)); exprs.push_back( b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); continue; } newShape.push_back(op.getShape(operand)[idx]); exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); reassociation.push_back({index++}); } newMaps.push_back( AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext())); // If the shape is unchanged the input doesn't change. if (newShape == op.getShape(operand)) { newInputs.push_back(operand->get()); continue; } Type newType = RankedTensorType::get( newShape, operand->get().getType().cast().getElementType()); Value newInput = b.create( loc, newType, operand->get(), reassociation); newInputs.push_back(newInput); } // Calculate the new output map and shape, we insert the new dimension based // on the index returned by `controlSplitReductionFn`. SmallVector newOutputShape; AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0)); ArrayRef oldShape = op.getShape(op.getOutputOperand(0)); SmallVector outputExpr; for (unsigned idx : llvm::seq(0, oldOutputMap.getNumResults() + 1)) { if (idx == insertDimIndex) { newOutputShape.push_back(ratio); outputExpr.push_back(b.getAffineDimExpr(insertDimIndex)); continue; } unsigned oldDim = idx < insertDimIndex ? idx : idx - 1; newOutputShape.push_back(oldShape[oldDim]); unsigned dim = oldOutputMap.getDimPosition(oldDim); outputExpr.push_back( b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); } Value initTensor = b.create( loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); Value constantOp = b.create(loc, *identity); Value identityTensor = b.create(op->getLoc(), constantOp, initTensor) .getResult(0); newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, op.getContext())); SmallVector newIteratorTypes; for (auto &it : llvm::enumerate(op.iterator_types())) { if (insertDimIndex == it.index()) newIteratorTypes.push_back(getParallelIteratorTypeName()); newIteratorTypes.push_back(it.value().cast().getValue()); } // Create the new op matching the original op with an extra parallel // dimension. GenericOp genericOp = b.create( loc, TypeRange({initTensor.getType()}), newInputs, ValueRange({identityTensor}), newMaps, newIteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.region(), genericOp.region().begin()); // Then create a new reduction that only reduce the newly added dimension from // the previous op. unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); SmallVector outputOperands = op.getOutputOperands(); SmallVector reductionIteratorTypes; SmallVector exprs; for (unsigned i : llvm::seq(0, intermRank)) { if (insertDimIndex == i) { reductionIteratorTypes.push_back(getReductionIteratorTypeName()); } else { exprs.push_back(b.getAffineDimExpr(i)); reductionIteratorTypes.push_back(getParallelIteratorTypeName()); } } AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); SmallVector reductionMaps = {inputMap, outputMap}; auto reduction = b.create( loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), outputOperands, reductionMaps, reductionIteratorTypes, [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { Operation *clonedReductionOp = b.clone(*reductionOp); clonedReductionOp->setOperand(0, inputs[0]); clonedReductionOp->setOperand(1, inputs[1]); b.create(loc, clonedReductionOp->getResult(0)); }); b.replaceOp(op, reduction.getResults()); filter.replaceLinalgTransformationFilter(b, genericOp); filter.replaceLinalgTransformationFilter(b, reduction); return cast(genericOp.getOperation()); } namespace { struct LinalgSplitReduction : public OpInterfaceRewritePattern { /// Construct a generic pattern applied to all LinalgOp that verify `filter`. LinalgSplitReduction(MLIRContext *context, ControlSplitReductionFn controlSplitReductionFn, LinalgTransformationFilter f, PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), controlSplitReductionFn(std::move(controlSplitReductionFn)), filter(std::move(f)) {} LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { return splitReduction(rewriter, op, controlSplitReductionFn, filter); } private: ControlSplitReductionFn controlSplitReductionFn; LinalgTransformationFilter filter; }; } // namespace void linalg::populateSplitReductionPattern( RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &f) { patterns.add(patterns.getContext(), controlSplitReductionFn, f); }