//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/StringSet.h" using namespace mlir; using namespace mlir::linalg; using namespace mlir::transform; /// Extracts a vector of unsigned from an array attribute. Asserts if the /// attribute contains values other than intergers. May truncate. static SmallVector extractUIntArray(ArrayAttr attr) { SmallVector result; result.reserve(attr.size()); for (APInt value : attr.getAsValueRange()) result.push_back(value.getZExtValue()); return result; } namespace { /// A simple pattern rewriter that implements no special logic. class SimpleRewriter : public PatternRewriter { public: SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} }; } // namespace /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` /// function that returns the "main" result or failure. Returns failure if the /// pattern failed to apply. Extra arguments are forwarded to the pattern /// constructor. 
template static FailureOr tryApply(Operation *operation, Args &&...args) { // Check if the given operation has the type expected by the pattern. using OpTy = typename llvm::function_traits< decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; auto op = dyn_cast(operation); if (!op) return failure(); // Apply the pattern directly to the op. PatternTy pattern(operation->getContext(), std::forward(args)...); SimpleRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) return failure(); return cast(result->getOperation()); } //===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::DecomposeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { FailureOr windowed = tryApply(target); if (succeeded(windowed)) { results.push_back(*windowed); return DiagnosedSilenceableFailure(success()); } FailureOr depthwise = tryApply(target); if (succeeded(depthwise)) { results.push_back(*depthwise); return DiagnosedSilenceableFailure(success()); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // FuseOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. 
static LogicalResult
applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
                 unsigned numLoops,
                 transform::TransformResults &transformResults,
                 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
  SmallVector<Operation *> tiledLinalgOps;
  // loopOps[i] collects, across all payload ops, the i-th generated loop.
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
  for (unsigned int i = 0; i < numLoops; ++i)
    loopOps[i].reserve(payloadOps.size());

  for (Operation *target : payloadOps) {
    auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
    if (!linalgOp)
      return transformOp->emitError("only LinalgOps are supported");

    FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
    if (failed(tiled))
      return failure();

    tiledLinalgOps.push_back(tiled->op);
    if (tiled->loops.size() != numLoops)
      // Not enough loops were generated. This usually means that the input size
      // was smaller than the tiling size.
      // TODO: LinalgTilingPattern should return failure().
      return failure();
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiled->loops[i]);
  }

  // Result #0 of the transform op maps to the tiled ops; results #1..numLoops
  // map to the loops, one result per loop level.
  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
  return success();
}

/// Parse a tiling-like operation that returns the tiled op as well as the
/// created tile loops. The function counts the non-zero tile sizes to compute
/// the number of results.
static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, StringRef sizesAttrName) { OpAsmParser::UnresolvedOperand targetOperand; SMLoc opLoc = parser.getCurrentLocation(); if (parser.parseOperand(targetOperand) || parser.parseOptionalAttrDict(result.attributes)) return failure(); Attribute sizesAttr = result.attributes.get(sizesAttrName); if (!sizesAttr) return parser.emitError(opLoc) << "expected '" << sizesAttrName << "' attribute"; auto sizesArrayAttr = sizesAttr.dyn_cast(); if (!sizesArrayAttr) return parser.emitError(opLoc) << "'" << sizesAttrName << "' attribute must be an array"; Type pdlOpType = parser.getBuilder().getType(); size_t numExpectedLoops = sizesArrayAttr.size() - llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOpType)); if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) return failure(); return success(); } DiagnosedSilenceableFailure transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { LinalgTilingAndFusionOptions fusionOptions; fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); LogicalResult result = applyTilingToAll( getOperation(), state.getPayloadOps(getTarget()), fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), transformResults, [&](LinalgOp linalgOp) -> FailureOr { LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(linalgOp); FailureOr tileLoopNest = pattern.returningMatchAndRewrite(linalgOp, rewriter); if (failed(tileLoopNest)) return failure(); TiledLinalgOp tiledLinalgOp; tiledLinalgOp.op = tileLoopNest->getRootOp(); tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), tileLoopNest->getLoopOps().end()}; return tiledLinalgOp; }); return 
DiagnosedSilenceableFailure(result); } ParseResult transform::FuseOp::parse(OpAsmParser &parser, OperationState &result) { return parseTileLikeOp( parser, result, transform::FuseOp::getTileSizesAttrName(result.name).getValue()); } void transform::FuseOp::print(OpAsmPrinter &p) { p << ' '; p << getTarget(); p.printOptionalAttrDict((*this)->getAttrs()); } LogicalResult transform::FuseOp::verify() { SmallVector permutation = extractFromI64ArrayAttr(getTileInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects interchange to be a permutation, found " << getTileInterchange(); } return success(); } //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// static FailureOr> tileAndFuse(Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) { auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) return failure(); // Search the producer slices accessed within the containing operation. // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe // evolve into an interface. SmallVector sliceOps; for (Operation *user : tileableProducer->getUsers()) { auto sliceOp = dyn_cast(user); if (!sliceOp) continue; if (!containingOp->isProperAncestor(sliceOp)) continue; sliceOps.push_back(sliceOp); } // Check for a non-empty list of fusion opportunities. if (sliceOps.empty()) return failure(); SmallVector destinationOperands = tileableProducer.getDestinationOperands(rewriter); // Try to fuse the producer in-place. SmallVector fusedOps; for (tensor::ExtractSliceOp sliceOp : sliceOps) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sliceOp); // Tile the producer. 
FailureOr tiledProducer = tileableProducer.generateResultTileValue( rewriter, /*resultNumber=*/0, destinationOperands, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true); if (failed(tiledProducer)) return failure(); fusedOps.push_back(tiledProducer->getDefiningOp()); } // Replace the extract op. for (const auto &en : enumerate(sliceOps)) rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0)); return fusedOps; } static FailureOr> cloneAndFuse(Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) { // Gather all uses inside the containing op. SmallVector uses; for (OpResult result : producerOp->getOpResults()) for (OpOperand &use : result.getUses()) if (containingOp->isProperAncestor(use.getOwner())) uses.push_back(&use); // Check for a non-empty list of fusion opportunities. if (uses.empty()) return failure(); // Clone and fuse inside the containing op. SmallVector fusedOps; for (OpOperand *use : uses) { unsigned resultNumber = use->get().cast().getResultNumber(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); Operation *cloned = rewriter.clone(*producerOp); rewriter.updateRootInPlace( use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); }); fusedOps.push_back(cloned); } return fusedOps; } DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; ArrayRef producerOps = state.getPayloadOps(getProducerOp()); for (Operation *producerOp : producerOps) { if (producerOp->getNumResults() != 1) { Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); diag << "op with != 1 results not supported"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } } ArrayRef containingOps = state.getPayloadOps(getContainingOp()); if (containingOps.size() != 1) return DiagnosedSilenceableFailure( this->emitOpError("requires exactly one containing_op handle")); 
Operation *containingOp = containingOps.front(); // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. SmallVector remainingProducers(producerOps.begin(), producerOps.end()); auto getNextProducer = [&]() -> FailureOr { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); bool hasUseInContainingOp = any_of(producerOp->getUsers(), [&](Operation *op) { return containingOp->isProperAncestor(op); }); // TODO: When resolving the TODO below (no duplicate ops), take an op that // has no use among the remaining producers. This is a topological // sorting. if (hasUseInContainingOp) { remainingProducers.erase(remainingProducers.begin() + it.index()); return producerOp; } } return failure(); }; IRRewriter rewriter(getContext()); while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note); diag << "could not fuse ops into container"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } Operation *producerOp = *nextProducer; // TODO: If there are multiple uses of the producer in the containing op, we // currently tile/clone the op multiple times (once per use). In some cases, // we can tile/clone once and reuse the value for each use. Futhermore, // producers should then be traversed according to a topological sorting. 
auto tiled = tileAndFuse(producerOp, containingOp, rewriter); if (succeeded(tiled)) fusedOps.append(*tiled); auto cloned = cloneAndFuse(producerOp, containingOp, rewriter); if (succeeded(cloned)) fusedOps.append(*cloned); if (failed(tiled) && failed(cloned)) { Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); diag << "could not fuse into containing op"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } } results.set(getFusedOp().cast(), fusedOps); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // GeneralizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { // Exit early if no transformation is needed. if (isa(target)) { results.push_back(target); return DiagnosedSilenceableFailure(success()); } FailureOr generic = tryApply(target); if (succeeded(generic)) { results.push_back(generic->getOperation()); return DiagnosedSilenceableFailure(success()); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // InterchangeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::InterchangeOp::applyToOne(linalg::GenericOp target, SmallVectorImpl &results, transform::TransformState &state) { SmallVector interchangeVector = extractUIntArray(getIteratorInterchange()); // Exit early if no transformation is needed. 
if (interchangeVector.empty()) { results.push_back(target); return DiagnosedSilenceableFailure(success()); } SimpleRewriter rewriter(target->getContext()); FailureOr res = interchangeGenericOp(rewriter, target, interchangeVector); if (failed(res)) return DiagnosedSilenceableFailure::definiteFailure(); results.push_back(res->getOperation()); return DiagnosedSilenceableFailure(success()); } LogicalResult transform::InterchangeOp::verify() { SmallVector permutation = extractUIntArray(getIteratorInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects iterator_interchange to be a permutation, found " << getIteratorInterchange(); } return success(); } //===---------------------------------------------------------------------===// // MatchOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { llvm::StringSet<> strs; if (getOps().has_value()) strs.insert(getOps()->getAsValueRange().begin(), getOps()->getAsValueRange().end()); ArrayRef payloadOps = state.getPayloadOps(getTarget()); if (payloadOps.size() != 1) return DiagnosedSilenceableFailure( this->emitOpError("requires exactly one target handle")); SmallVector res; auto matchFun = [&](Operation *op) { if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) return WalkResult::advance(); // Interfaces cannot be matched by name, just by ID. // So we specifically encode the interfaces we care about for this op. 
if (getInterface().has_value()) { auto iface = getInterface().value(); if (iface == transform::MatchInterfaceEnum::LinalgOp && !isa(op)) return WalkResult::advance(); if (iface == transform::MatchInterfaceEnum::TilingInterface && isa(op)) return WalkResult::advance(); } if (getAttribute().has_value() && !op->hasAttr(getAttribute().value())) return WalkResult::advance(); // All constraints are satisfied. res.push_back(op); return WalkResult::advance(); }; payloadOps.front()->walk(matchFun); results.set(getResult().cast(), res); return DiagnosedSilenceableFailure(success()); } //===---------------------------------------------------------------------===// // MultiTileSizesOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( LinalgOp target, SmallVector &results, TransformState &state) { OpBuilder builder(target.getContext()); builder.setInsertionPoint(target); OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); OpFoldResult divisor = builder.getIndexAttr(getDivisor()); FailureOr spec = computeMultiTileSizes( builder, target, getDimension(), targetSize, divisor); if (failed(spec)) { return emitSilenceableError() << "could not generate tile size computation"; } AffineExpr s0 = builder.getAffineSymbolExpr(0); AffineExpr s1 = builder.getAffineSymbolExpr(1); Operation *splitPoint = makeComposedAffineApply(builder, target.getLoc(), s0 * s1, {spec->lowTileSize, spec->lowTripCount}); Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); Operation *highTileSize = spec->highTileSize.getDefiningOp(); assert(lowTileSize && highTileSize && splitPoint && "tile sizes are not produced by operations"); results.reserve(results.size() + 3); results.push_back(lowTileSize); results.push_back(highTileSize); results.push_back(splitPoint); return DiagnosedSilenceableFailure::success(); } void transform::MultiTileSizesOp::getEffects( SmallVectorImpl &effects) { 
onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } //===---------------------------------------------------------------------===// // PadOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PadOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. SmallVector packPaddings; for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) packPaddings.push_back(static_cast(packPadding)); // Convert the padding values to attributes. SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { Attribute attr = std::get<0>(it); Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = attr.dyn_cast()) { paddingValues.push_back( parseAttribute(attr.cast(), elementType)); if (!paddingValues.back()) { auto diag = this->emitOpError("expects a padding that parses to ") << elementType << ", got " << std::get<0>(it); diag.attachNote(target.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } continue; } // Otherwise, add the attribute directly. if (attr.getType() != elementType) { auto diag = this->emitOpError("expects a padding value of type ") << elementType << ", got " << attr; diag.attachNote(target.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } paddingValues.push_back(attr); } // Extract the transpose vectors. 
SmallVector> transposePaddings; for (Attribute transposeVector : getTransposePaddings().cast()) transposePaddings.push_back( extractFromI64ArrayAttr(transposeVector.cast())); LinalgPaddingOptions paddingOptions; paddingOptions.setPaddingValues(paddingValues); paddingOptions.setPaddingDimensions( extractFromI64ArrayAttr(getPaddingDimensions())); paddingOptions.setPackPaddings(packPaddings); paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); paddingOptions.setTransposePaddings(transposePaddings); FailureOr result = tryApply(target, paddingOptions); if (succeeded(result)) { results.push_back(result->getOperation()); return DiagnosedSilenceableFailure(success()); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } LogicalResult transform::PadOp::verify() { SmallVector packPaddings = extractFromI64ArrayAttr(getPackPaddings()); if (any_of(packPaddings, [](int64_t packPadding) { return packPadding != 0 && packPadding != 1; })) { return emitOpError() << "expects pack_paddings to contain booleans (0/1), found " << getPackPaddings(); } SmallVector paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { return emitOpError() << "expects padding_dimensions to contain positive integers, found " << getPaddingDimensions(); } SmallVector hoistPaddings = extractFromI64ArrayAttr(getHoistPaddings()); if (any_of(hoistPaddings, [](int64_t hoistPadding) { return hoistPadding < 0; })) { return emitOpError() << "expects hoist_paddings to contain positive integers, found " << getHoistPaddings(); } ArrayAttr transposes = getTransposePaddings(); for (Attribute attr : transposes) { SmallVector transpose = extractFromI64ArrayAttr(attr); auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), transpose.end())) { return emitOpError() << "expects 
transpose_paddings to be a permutation, found " << attr; } } return success(); } //===----------------------------------------------------------------------===// // PromoteOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PromoteOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; if (!getOperandsToPromote().empty()) promotionOptions = promotionOptions.setOperandsToPromote( extractFromI64ArrayAttr(getOperandsToPromote())); if (getUseFullTilesByDefault()) promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( getUseFullTilesByDefault()); if (getUseAlloca()) promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); if (!getUseFullTileBuffers().empty()) promotionOptions = promotionOptions.setUseFullTileBuffers( llvm::to_vector(getUseFullTileBuffers().getAsValueRange())); if (getAlignment().has_value()) promotionOptions = promotionOptions.setAlignment(*getAlignment()); if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); SimpleRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(target); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // ScalarizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 
// sizes and asserts that it is not already set. SmallVector emptyTileSizes; LinalgTilingPattern pattern(getContext(), tilingOptions); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr result = pattern.returningMatchAndRewrite(target, rewriter); if (failed(result)) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(result->op); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // SplitOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. ArrayRef payload = state.getPayloadOps(getTarget()); SimpleRewriter rewriter(getContext()); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); splitPoints = llvm::to_vector(llvm::map_range( state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { diag = emitSilenceableError() << "expected dynamic split point handle to point to a " "single-result index-typed op"; diag.attachNote(op->getLoc()) << "dynamic split point"; } return OpFoldResult(op->getResult(0)); })); if (!diag.succeeded()) return diag; if (splitPoints.size() != payload.size()) { emitError() << "expected the dynamic split point handle to point to as " "many operations (" << splitPoints.size() << ") as the target handle (" << payload.size() << ")"; return DiagnosedSilenceableFailure::definiteFailure(); } } else { splitPoints.resize(payload.size(), rewriter.getIndexAttr(getStaticSplitPoint())); } // Split each target operation. 
SmallVector first, second; for (const auto &pair : llvm::zip(payload, splitPoints)) { Operation *target = std::get<0>(pair); auto linalgOp = dyn_cast(target); if (!linalgOp) { auto diag = emitSilenceableError() << "only applies to structured ops"; diag.attachNote(target->getLoc()) << "target op"; return diag; } if (getDimension() >= linalgOp.getNumLoops()) { auto diag = emitSilenceableError() << "dimension " << getDimension() << " does not exist in target op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } rewriter.setInsertionPoint(linalgOp); std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); } results.set(getFirst().cast(), first); results.set(getSecond().cast(), second); return DiagnosedSilenceableFailure::success(); } void SplitOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); if (getDynamicSplitPoint()) onlyReadsHandle(getDynamicSplitPoint(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; IntegerAttr staticSplitPoint; auto pdlOperationType = pdl::OperationType::get(parser.getBuilder().getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parser.parseKeyword("after")) return failure(); OptionalParseResult dynamicPointParseResult = parser.parseOptionalOperand(dynamicSplitPoint); if (!dynamicPointParseResult.hasValue()) { int64_t staticSplitPointValue; if (failed(parser.parseInteger(staticSplitPointValue))) return failure(); staticSplitPoint = parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); } else { if (failed(*dynamicPointParseResult) || parser.resolveOperand(dynamicSplitPoint, pdlOperationType, result.operands)) { return failure(); } staticSplitPoint = 
parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); } result.addAttribute( SplitOp::getStaticSplitPointAttrName(result.name).getValue(), staticSplitPoint); if (failed(parser.parseOptionalAttrDict(result.attributes))) return failure(); result.addTypes({pdlOperationType, pdlOperationType}); return success(); } void SplitOp::print(OpAsmPrinter &printer) { printer << " " << getTarget() << " after "; int64_t staticSplitSize = static_cast(getStaticSplitPoint()); if (staticSplitSize != ShapedType::kDynamicSize) printer << staticSplitSize; else printer << getDynamicSplitPoint(); printer << " "; printer.printOptionalAttrDict(getOperation()->getAttrs(), {getStaticSplitPointAttrName()}); } LogicalResult SplitOp::verify() { if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamicSize) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split point to be provided"; } return success(); } //===----------------------------------------------------------------------===// // SplitReductionOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return std::pair(getSplitFactor(), getInsertSplitDimension()); }; SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) ? 
splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(splitResult->initOrAlloc); results.push_back(splitResult->fillOp); results.push_back(splitResult->splitLinalgOp); results.push_back(splitResult->resultCombiningLinalgOp); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { LinalgTilingOptions tilingOptions; SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; dynamicSizeProducers.reserve(getDynamicSizes().size()); for (Value dynamicSizeProducerHandle : getDynamicSizes()) { dynamicSizeProducers.push_back( state.getPayloadOps(dynamicSizeProducerHandle)); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many dynamic size-producing operations (" << dynamicSizeProducers.back().size() << ") as target ops (" << targets.size() << ")"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && op->getResult(0).getType().isa()) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " "with a single index-type result"; diag.attachNote(op->getLoc()) << "size producer op"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } } SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); for (auto &en : 
llvm::enumerate(targets)) { auto linalgOp = dyn_cast(en.value()); if (!linalgOp) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "only linalg ops are supported"; diag.attachNote(en.value()->getLoc()) << "target op"; return diag; } unsigned index = en.index(); if (!tileSizes.empty()) { tilingOptions.setTileSizeComputationFunction( [&, index](OpBuilder &b, Operation *) { SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( getLoc(), attr.cast().getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); } } return sizes; }); } tilingOptions.setInterchange(extractUIntArray(getInterchange())); LinalgTilingPattern pattern(getContext(), tilingOptions); SimpleRewriter rewriter(linalgOp.getContext()); FailureOr tiledOp = pattern.returningMatchAndRewrite(linalgOp, rewriter); if (failed(tiledOp)) return DiagnosedSilenceableFailure::definiteFailure(); tiled.push_back(tiledOp->op); for (const auto &en2 : llvm::enumerate(tiledOp->loops)) loops[en2.index()].push_back(en2.value()); } transformResults.set(getTiledLinalgOp().cast(), tiled); for (const auto &en : llvm::enumerate(loops)) transformResults.set(getLoops()[en.index()].cast(), en.value()); return DiagnosedSilenceableFailure::success(); } SmallVector transform::TileOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; Builder builder(getContext()); for (int64_t size : tileSizes) { if (size == ShapedType::kDynamicSize) { results.push_back(dynamic[dynamicPos++]); } else { results.push_back(builder.getIndexAttr(size)); } } return results; } ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; ArrayAttr 
staticSizes; auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || parser.parseOptionalAttrDict(result.attributes)) return ParseResult::failure(); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), getStaticSizes()); p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); } void transform::TileOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getDynamicSizes(), effects); producesHandle(getTiledLinalgOp(), effects); producesHandle(getLoops(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TileToForeachThreadOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( TilingInterface target, SmallVectorImpl &results, transform::TransformState &state) { IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); auto maybeThreadDimMappingAttr = getThreadDimMapping(); auto dimMapping = llvm::to_vector(maybeThreadDimMappingAttr ? 
extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) : ArrayRef{}); FailureOr tilingResult = failure(); if (Optional numThreads = getNumThreads()) tilingResult = linalg::tileToForeachThreadOp( rewriter, target, getAsOpFoldResult(*numThreads), dimMapping); if (Optional tileSizes = getTileSizes()) tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping); if (failed(tilingResult)) return emitDefaultSilenceableFailure(target); rewriter.replaceOp(target, tilingResult->tileOp->getResults()); results.assign({tilingResult->tileOp, tilingResult->tiledOp}); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // VectorizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::VectorizeOp::applyToOne(Operation *target, SmallVectorImpl &results, transform::TransformState &state) { if (!target->hasTrait()) { auto diag = this->emitOpError("requires isolated-from-above targets"); diag.attachNote(target->getLoc()) << "non-isolated target"; return DiagnosedSilenceableFailure::definiteFailure(); } MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); vector::populateVectorReductionToContractPatterns(patterns); patterns.add(ctx, /*benefit=*/2); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(target); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // Transform op 
registration //===----------------------------------------------------------------------===// namespace { /// Registers new ops and declares PDL as dependent dialect since the additional /// ops are using PDL types for operands and results. class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { public: using Base::Base; void init() { declareDependentDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" >(); } }; } // namespace #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" void mlir::linalg::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }