//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic and helpers to expose Linalg transforms as rewrite // patterns. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include #define DEBUG_TYPE "linalg-transforms" using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. //===----------------------------------------------------------------------===// // Marker used as attribute name in generated Linalg rewriting transformations. 
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = "__internal_linalg_transform__"; mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, Optional replacement) : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), replacement(replacement) {} LogicalResult mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, Operation *op) const { auto attr = op->template getAttrOfType( LinalgTransforms::kLinalgTransformMarker); if (!attr) { // 1. Has no marker case and matchDisjunction is empty. if (matchDisjunction.empty()) return success(); // 2. Has no marker but was expecting a marker. return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << " does not have any marker from list: "; interleaveComma(matchDisjunction, diag); }); } // 4. Match explicit marker. for (auto marker : matchDisjunction) if (attr.getValue() == marker) return success(); // 5. Fail to match. return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << " does not have any marker from list: "; interleaveComma(matchDisjunction, diag); }); } void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const { if (replacement.hasValue()) op->setAttr(LinalgTransforms::kLinalgTransformMarker, rewriter.getStringAttr(replacement.getValue())); else op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, rewriter.getContext())); } LinalgTilingOptions & mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { SmallVector tileSizes(ts.begin(), ts.end()); tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart( &op->getParentOfType().getBody().front()); return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { Value v = b.create(op->getLoc(), s); return v; })); }; return *this; } /// Linalg base tiling pattern. 
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

/// Tiles `op` with the stored options when it is a LinalgOp accepted by the
/// marker filter. On success the tiled op receives the replacement marker and
/// the original op is erased. Fails without changes when `op` is not a
/// LinalgOp, the marker does not match, or tiling returns no result.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
  if (!res)
    return failure();

  // New marker if specified.
  marker.replaceLinalgMarker(rewriter, res->op.getOperation());

  rewriter.eraseOp(op);
  return success();
}

/// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
    StringRef opName, MLIRContext *context,
    ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

/// Applies the stored interchange vector to `op` in place when it is a
/// LinalgOp accepted by the marker filter and the interchange precondition
/// holds, then updates its marker. Fails without changes otherwise.
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(op, [&]() {
    interchange(linalgOp, interchangeVector);
    // New marker if specified.
    marker.replaceLinalgMarker(rewriter, op);
  });
  return success();
}

/// Linalg base promotion pattern.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}

/// Promotes the subviews of `op` with the stored options when the marker
/// filter accepts `op` and the promotion precondition holds. On success the
/// marker of `op` is updated. When promotion itself fails, the root update is
/// cancelled and an error is emitted on `op`.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(marker.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  // Replace the marker before finalizing: it mutates the root op's attributes
  // and therefore belongs inside the start/finalize root-update bracket.
  marker.replaceLinalgMarker(rewriter, op);
  rewriter.finalizeRootUpdate(op);
  return success();
}

/// Linalg base vectorization pattern.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker) {}

/// Vectorizes `op` and erases it when it is a LinalgOp accepted by the marker
/// filter and the vectorization precondition holds. Fails without changes
/// otherwise.
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  if (failed(vectorizeLinalgOpPrecondition(op)))
    return failure();
  vectorizeLinalgOp(rewriter, op);
  rewriter.eraseOp(op);
  return success();
}

/// Runs a staged rewrite on `op`. For each pattern list in `stage1Patterns`,
/// in order: greedily applies that list, then greedily applies
/// `stage2Patterns`, then calls `stage3Lambda` on `op` if it is non-null.
/// Returns failure as soon as any of these steps fails, success otherwise.
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
    const OwningRewritePatternList &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  // 1-based index of the current stage-1 pattern list; only read by debug
  // output, hence the void cast for builds where LLVM_DEBUG expands to nothing.
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    // Advance once per pass and outside LLVM_DEBUG, so that every message of
    // the same pass reports the same iteration number.
    ++iteration;
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}