//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as
// rewrite patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.has_value())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == *replacement;
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize)
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
/// Exit early and return the `opOperand` value if the shape dimensions that
/// match `paddingDimensions` have a static size and the nofold flag is not
/// set. Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions `paddingDimensions` and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    paddedShape[shapeIdx++] = *upperBound;
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel and canonicalize `loops`.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops) {
    SmallVector<Value, 4> loopResults = peelLoop(rewriter, loopOp);
    (void)loopResults;
  }
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter,
                                     TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
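
// For intuition, peeling an scf.for loop splits it so that the main loop only
// executes full iterations. A sketch, assuming step 4 and a dynamic upper
// bound %ub (value names are illustrative, not produced verbatim):
//
//   scf.for %iv = %c0 to %ub step %c4 { ... }
//
// becomes
//
//   %split = affine.apply affine_map<()[s0] -> (s0 - s0 mod 4)>()[%ub]
//   scf.for %iv = %c0 to %split step %c4 { ... }  // full iterations only
//   scf.for %iv = %split to %ub step %c4 { ... }  // at most one partial
//                                                 // iteration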
/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand), ShapedType::isDynamic))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, *newResult);

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, *newResults);
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}
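
// For intuition, padding an operand to its static bounding box produces IR of
// roughly the following shape (a sketch, assuming a dynamic operand that can
// be bounded by 8x16; value names are illustrative):
//
//   %cst = arith.constant 0.0 : f32
//   %padded = tensor.pad %operand low[0, 0] high[%h0, %h1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<?x?xf32> to tensor<8x16xf32>
//
// The linalg op is then cloned to operate on the padded operands, and
// tensor.extract_slice ops recover the original result sizes.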
/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains non-zero tile sizes.
  if (llvm::count(rootTileSizes, 0) ==
      static_cast<int64_t>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}
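
// A minimal configuration sketch for the tile-and-fuse pattern above (`ctx` is
// illustrative and not defined in this file):
//
//   LinalgTilingAndFusionOptions fuseOptions;
//   fuseOptions.tileSizes = {8, 16, 0};      // A 0 leaves that loop untiled.
//   fuseOptions.tileInterchange = {1, 0, 2}; // Must permute the root loops.
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgTileAndFuseTensorOpsPattern>(ctx, fuseOptions);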
/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern<GenericOp>(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // New filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}
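
// For intuition, generalization replaces a named structured op by its
// linalg.generic form, e.g. (IR sketch; types abbreviated):
//
//   linalg.matmul ins(%A, %B : ...) outs(%C : ...)
//
// becomes a linalg.generic with the matmul indexing maps
//   (d0, d1, d2) -> (d0, d2), (d2, d1), (d0, d1)
// and an explicit multiply-accumulate region.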
mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPeelOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    StringRef opName, MLIRContext *context, LinalgPeelOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Increase marker counter even if peeling doesn't happen for this op.
  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);

  if (!options.loopsToPeelComputationFunction)
    return failure();

  SmallVector<scf::ForOp, 4> loopsToPeel;
  options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
  peelLoops(rewriter, loopsToPeel);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef, 3> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef, 3>(nParallelLoops,
                                   getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
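///
/// For a fully static pad, the produced IR is roughly (a sketch; value names
/// and shapes are illustrative):
///
///   %init = linalg.init_tensor [4, 6] : tensor<4x6xf32>
///   %fill = linalg.fill ins(%pad_val : f32) outs(%init : tensor<4x6xf32>)
///   %0 = linalg.generic ...   // copies the source into %fill, shifted by
///                             // the low padding amounts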
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
  auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.getRegion().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.getValue();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy original contents into new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.getStaticLow()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(), /*symbolCount=*/0,
                     outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using FillOp with a constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // tensor::PadOps cannot be optimized. Generate an InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
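//
// For example, a convolution whose H window dimension is statically of size 1
// (a sketch; shapes are illustrative):
//
//   linalg.conv_2d_nhwc_hwcf operating on a tensor<?x1x?x?xf32> input
//
// is rewritten into rank-reducing tensor.extract_slice ops on all operands, a
// linalg.conv_1d_nwc_wcf on the rank-reduced values, and a tensor.insert_slice
// that puts the result back into the original output tensor.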
FailureOr<Conv1DNwcWcfOp>
DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DNwcWcfOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                  filter, benefit);
}
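
// A minimal usage sketch for the patterns above (`ctx` and `funcOp` are
// illustrative and not defined in this file):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns,
//                                                LinalgTransformationFilter());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));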