//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
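// For example (illustrative IR), an op carrying the marker looks like:
//   linalg.matmul {__internal_linalg_transform__ = "tiled"} ...
// The filters below match on this attribute and rewrite it to sequence staged
// transformations.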
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker: match if the disjunction is empty or matching
    //    by default is enabled.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. The op has no marker but the filter expected one.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match against the explicit disjunction.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
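  // For example (hypothetical): an op marked "promoted" reaches this point
  // when the filter was built with matchDisjunction = {"tile", "vectorize"}.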
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
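    // Example (sketch): for an op whose loop ranges are (?, 8, ?), the
    // computed tile sizes are (1, 0, 1): dynamic dimensions are scalarized
    // (tiled by 1) and static dimensions are left untiled (tile size 0).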
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Helper function that tries to pad `opOperand`. Exit early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
/// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
/// has a static shape. Set `result` to the result of the created PadTensorOp
/// and return success if the operand either has been padded to a static shape
/// or already had a static shape; return failure otherwise.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
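    // `getMixedSizes` returns each size as an OpFoldResult: an IntegerAttr
    // when the size is static, or a Value when it is dynamic. For example
    // (illustrative), the sizes of
    //   tensor.extract_slice %t[0, 0] [4, %sz] [1, 1]
    // are {4 : index, %sz}.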
    // If the size is an attribute add it directly to `staticSizes`.
    if (size.is<Attribute>()) {
      staticSizes.push_back(
          size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(size.get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }
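
  // Reify the result shapes of the original op before cloning: the padded
  // clone has larger static shapes, and the original sizes are needed below to
  // slice the padded results back to their expected extents.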
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}
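
// Typical usage (sketch, assuming the LinalgTilingPattern<OpTy> wrapper
// declared in Transforms.h): tile matmuls marked "tile" and re-mark the result
// "tiled" so the pattern does not fire on its own output:
//   patterns.add<LinalgTilingPattern<linalg::MatmulOp>>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 16, 32}),
//       LinalgTransformationFilter(StringAttr::get(ctx, "tile"),
//                                  StringAttr::get(ctx, "tiled")));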

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but continue if this is not the
    // case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
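
  // Example (sketch): with tile sizes (8, 16, 32) and fused loop dims {0},
  // the unfused tile sizes become (0, 16, 32); only the loops that were not
  // already materialized by fusion get tiled a second time.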
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}
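
// Typical usage (sketch; `getConstantPadValue` and the hoisting depth of 2 are
// hypothetical choices): pad every operand and hoist the generated
// linalg.pad_tensor ops two loops up:
//   LinalgPaddingOptions options;
//   options.paddingValueComputationFunction = getConstantPadValue;
//   options.paddingHoistComputationFunction =
//       [](OpOperand &) -> int64_t { return 2; };
//   patterns.add<LinalgPaddingPattern>(ctx, options, filter);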

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");
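
  // Example (sketch): for a matmul with loops (i, j, k), tileInterchange
  // (0, 2, 1) materializes the tile loops in (i, k, j) order; a list such as
  // (0, 0, 1) is rejected because it repeats a loop dimension.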

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
      rewriter, rootOp, rootTileSizes, rootInterchange);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
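
// Example (sketch): interchangeVector = {1, 0} on a 2-D linalg.generic swaps
// the two loops by permuting the op's indexing maps and iterator types; the
// computed results are unchanged, only the iteration order is.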

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
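  // startRootUpdate/finalizeRootUpdate bracket the in-place modification so
  // the driver re-processes the op; cancelRootUpdate only cancels the pending
  // notification and does not undo IR changes made in between.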
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
/// with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
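  // For example (illustrative IR), %cst qualifies in:
  //   %0 = linalg.pad_tensor %arg0 low[...] high[...] {
  //   ^bb0(%i: index, %j: index):
  //     linalg.yield %cst : f32
  //   } : tensor<?x?xf32> to tensor<8x16xf32>
  // as long as %cst is defined above the pad op.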
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create the tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
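
// The rewrite above produces, roughly (illustrative shapes):
//   %init = linalg.init_tensor [8, 16] : tensor<8x16xf32>
//   %fill = linalg.fill(%pad_value, %init)
//   %0 = linalg.generic ... ins(%source) outs(%fill)  // shifted identity copy
// i.e. a fill of the padded shape followed by a copy of the source into it.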

/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }
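
  // Example (sketch): for a dimension with dynamic source size %s, low pad 2,
  // and high pad %h, the result size above is computed as (%s + 2) + %h.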

  // Init the tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead for
  // copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp =
      padOp
          .getTiledImplementation(
              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
          .front();
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
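
// The swap above rewrites, schematically:
//   tensor.extract_slice(linalg.pad_tensor(%x))
// into
//   linalg.pad_tensor(tensor.extract_slice(%x))
// so the slice is taken on the unpadded source and only the needed part is
// padded.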

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
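    // As a sketch of what the helper below produces (the types here are
    // hypothetical), a rank-reducing extract_slice drops the unit dimension:
    //   %in = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 32, 3]
    //       [1, 1, 1, 1] : tensor<1x1x32x3xf32> to tensor<1x32x3xf32>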
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert the 1-D result back into the original output.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformationFilter handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
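///
/// For instance, with a size-1 kh/oh pair (hypothetical shapes, for
/// illustration only), this rewrites roughly:
///   linalg.depthwise_conv_2d_nhwc_hwc
///     ins(%in, %ker : tensor<1x1x32x3xf32>, tensor<1x2x3xf32>)
///     outs(%out : tensor<1x1x31x3xf32>)
/// into:
///   linalg.depthwise_conv_1d_nwc_wc
///     ins(%in1, %ker1 : tensor<1x32x3xf32>, tensor<2x3xf32>)
///     outs(%out1 : tensor<1x31x3xf32>)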
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(filter) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert the 1-D result back into the original output.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformationFilter handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, LinalgTransformationFilter filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
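
// Example usage from a pass (a minimal sketch; the surrounding pass class and
// the default filter/benefit arguments declared in Transforms.h are assumed):
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateDecomposeConvolutionPatterns(
//         patterns, LinalgTransformationFilter());
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }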