//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::linalg; using namespace mlir::transform; /// Extracts a vector of unsigned from an array attribute. Asserts if the /// attribute contains values other than intergers. May truncate. static SmallVector extractUIntArray(ArrayAttr attr) { SmallVector result; result.reserve(attr.size()); for (APInt value : attr.getAsValueRange()) result.push_back(value.getZExtValue()); return result; } namespace { /// A simple pattern rewriter that implements no special logic. class SimpleRewriter : public PatternRewriter { public: SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} }; } // namespace /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` /// function that returns the "main" result or failure. Returns failure if the /// pattern failed to apply. Extra arguments are forwarded to the pattern /// constructor. 
template static FailureOr tryApply(Operation *operation, Args &&...args) { // Check if the given operation has the type expected by the pattern. using OpTy = typename llvm::function_traits< decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; auto op = dyn_cast(operation); if (!op) return failure(); // Apply the pattern directly to the op. PatternTy pattern(operation->getContext(), std::forward(args)...); SimpleRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) return failure(); return cast(result->getOperation()); } //===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::DecomposeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { FailureOr windowed = tryApply(target); if (succeeded(windowed)) { results.push_back(*windowed); return DiagnosedSilenceableFailure(success()); } FailureOr depthwise = tryApply(target); if (succeeded(depthwise)) { results.push_back(*depthwise); return DiagnosedSilenceableFailure(success()); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // FuseOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. 
/// `applyFn` is invoked once per payload op. Result 0 of `transformOp` is
/// associated with the tiled ops and result `i + 1` with the i-th loop of each
/// tiled op. Fails if a payload op is not a LinalgOp, if `applyFn` fails, or if
/// the number of produced loops differs from `numLoops`.
static LogicalResult
applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
                 unsigned numLoops,
                 transform::TransformResults &transformResults,
                 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
  SmallVector<Operation *> tiledLinalgOps;
  tiledLinalgOps.reserve(payloadOps.size());
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
  for (unsigned int i = 0; i < numLoops; ++i)
    loopOps[i].reserve(payloadOps.size());

  for (Operation *target : payloadOps) {
    auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
    if (!linalgOp)
      return transformOp->emitError("only LinalgOps are supported");

    FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
    if (failed(tiled))
      return failure();

    tiledLinalgOps.push_back(tiled->op);
    if (tiled->loops.size() != numLoops)
      // Not enough loops were generated. This usually means that the input size
      // was smaller than the tiling size.
      // TODO: LinalgTilingPattern should return failure().
      return failure();
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiled->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
  return success();
}

/// Parse a tiling-like operation that returns the tiled op as well as the
/// created tile loops. The function counts the non-zero tile sizes to compute
/// the number of results.
static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, StringRef sizesAttrName) { OpAsmParser::UnresolvedOperand targetOperand; SMLoc opLoc = parser.getCurrentLocation(); if (parser.parseOperand(targetOperand) || parser.parseOptionalAttrDict(result.attributes)) return failure(); Attribute sizesAttr = result.attributes.get(sizesAttrName); if (!sizesAttr) return parser.emitError(opLoc) << "expected '" << sizesAttrName << "' attribute"; auto sizesArrayAttr = sizesAttr.dyn_cast(); if (!sizesArrayAttr) return parser.emitError(opLoc) << "'" << sizesAttrName << "' attribute must be an array"; Type pdlOpType = parser.getBuilder().getType(); size_t numExpectedLoops = sizesArrayAttr.size() - llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOpType)); if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) return failure(); return success(); } DiagnosedSilenceableFailure transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { LinalgTilingAndFusionOptions fusionOptions; fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); LogicalResult result = applyTilingToAll( getOperation(), state.getPayloadOps(getTarget()), fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), transformResults, [&](LinalgOp linalgOp) -> FailureOr { LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(linalgOp); FailureOr tileLoopNest = pattern.returningMatchAndRewrite(linalgOp, rewriter); if (failed(tileLoopNest)) return failure(); TiledLinalgOp tiledLinalgOp; tiledLinalgOp.op = tileLoopNest->getRootOp(); tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), tileLoopNest->getLoopOps().end()}; return tiledLinalgOp; }); return 
DiagnosedSilenceableFailure(result); } ParseResult transform::FuseOp::parse(OpAsmParser &parser, OperationState &result) { return parseTileLikeOp( parser, result, transform::FuseOp::getTileSizesAttrName(result.name).getValue()); } void transform::FuseOp::print(OpAsmPrinter &p) { p << ' '; p << getTarget(); p.printOptionalAttrDict((*this)->getAttrs()); } LogicalResult transform::FuseOp::verify() { SmallVector permutation = extractFromI64ArrayAttr(getTileInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects interchange to be a permutation, found " << getTileInterchange(); } return success(); } //===----------------------------------------------------------------------===// // GeneralizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { // Exit early if no transformation is needed. if (isa(target)) { results.push_back(target); return DiagnosedSilenceableFailure(success()); } FailureOr generic = tryApply(target); if (succeeded(generic)) { results.push_back(generic->getOperation()); return DiagnosedSilenceableFailure(success()); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // InterchangeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::InterchangeOp::applyToOne(linalg::GenericOp target, SmallVectorImpl &results, transform::TransformState &state) { SmallVector interchangeVector = extractUIntArray(getIteratorInterchange()); // Exit early if no transformation is needed. 
if (interchangeVector.empty()) { results.push_back(target); return DiagnosedSilenceableFailure(success()); } SimpleRewriter rewriter(target->getContext()); FailureOr res = interchangeGenericOp(rewriter, target, interchangeVector); if (failed(res)) return DiagnosedSilenceableFailure::definiteFailure(); results.push_back(res->getOperation()); return DiagnosedSilenceableFailure(success()); } LogicalResult transform::InterchangeOp::verify() { SmallVector permutation = extractUIntArray(getIteratorInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects iterator_interchange to be a permutation, found " << getIteratorInterchange(); } return success(); } //===---------------------------------------------------------------------===// // MultiTileSizesOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( LinalgOp target, SmallVector &results, TransformState &state) { OpBuilder builder(target.getContext()); builder.setInsertionPoint(target); OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); OpFoldResult divisor = builder.getIndexAttr(getDivisor()); FailureOr spec = computeMultiTileSizes( builder, target, getDimension(), targetSize, divisor); if (failed(spec)) { return emitSilenceableError() << "could not generate tile size computation"; } AffineExpr s0 = builder.getAffineSymbolExpr(0); AffineExpr s1 = builder.getAffineSymbolExpr(1); Operation *splitPoint = makeComposedAffineApply(builder, target.getLoc(), s0 * s1, {spec->lowTileSize, spec->lowTripCount}); Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); Operation *highTileSize = spec->highTileSize.getDefiningOp(); assert(lowTileSize && highTileSize && splitPoint && "tile sizes are not produced by operations"); results.reserve(results.size() + 3); 
results.push_back(lowTileSize); results.push_back(highTileSize); results.push_back(splitPoint); return DiagnosedSilenceableFailure::success(); } void transform::MultiTileSizesOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } //===---------------------------------------------------------------------===// // PadOp //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PadOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. SmallVector packPaddings; for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) packPaddings.push_back(static_cast(packPadding)); // Convert the padding values to attributes. SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { Attribute attr = std::get<0>(it); Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = attr.dyn_cast()) { paddingValues.push_back( parseAttribute(attr.cast(), elementType)); if (!paddingValues.back()) { auto diag = this->emitOpError("expects a padding that parses to ") << elementType << ", got " << std::get<0>(it); diag.attachNote(target.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } continue; } // Otherwise, add the attribute directly. if (attr.getType() != elementType) { auto diag = this->emitOpError("expects a padding value of type ") << elementType << ", got " << attr; diag.attachNote(target.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } paddingValues.push_back(attr); } // Extract the transpose vectors. 
SmallVector> transposePaddings; for (Attribute transposeVector : getTransposePaddings().cast()) transposePaddings.push_back( extractFromI64ArrayAttr(transposeVector.cast())); LinalgPaddingOptions paddingOptions; paddingOptions.setPaddingValues(paddingValues); paddingOptions.setPaddingDimensions( extractFromI64ArrayAttr(getPaddingDimensions())); paddingOptions.setPackPaddings(packPaddings); paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); paddingOptions.setTransposePaddings(transposePaddings); FailureOr result = tryApply(target, paddingOptions); if (succeeded(result)) { results.push_back(result->getOperation()); return DiagnosedSilenceableFailure(success()); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } LogicalResult transform::PadOp::verify() { SmallVector packPaddings = extractFromI64ArrayAttr(getPackPaddings()); if (any_of(packPaddings, [](int64_t packPadding) { return packPadding != 0 && packPadding != 1; })) { return emitOpError() << "expects pack_paddings to contain booleans (0/1), found " << getPackPaddings(); } SmallVector paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { return emitOpError() << "expects padding_dimensions to contain positive integers, found " << getPaddingDimensions(); } SmallVector hoistPaddings = extractFromI64ArrayAttr(getHoistPaddings()); if (any_of(hoistPaddings, [](int64_t hoistPadding) { return hoistPadding < 0; })) { return emitOpError() << "expects hoist_paddings to contain positive integers, found " << getHoistPaddings(); } ArrayAttr transposes = getTransposePaddings(); for (Attribute attr : transposes) { SmallVector transpose = extractFromI64ArrayAttr(attr); auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), transpose.end())) { return emitOpError() << "expects 
transpose_paddings to be a permutation, found " << attr; } } return success(); } //===----------------------------------------------------------------------===// // PromoteOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::PromoteOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; if (!getOperandsToPromote().empty()) promotionOptions = promotionOptions.setOperandsToPromote( extractFromI64ArrayAttr(getOperandsToPromote())); if (getUseFullTilesByDefault()) promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( getUseFullTilesByDefault()); if (getUseAlloca()) promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); if (!getUseFullTileBuffers().empty()) promotionOptions = promotionOptions.setUseFullTileBuffers( llvm::to_vector(getUseFullTileBuffers().getAsValueRange())); if (getAlignment().has_value()) promotionOptions = promotionOptions.setAlignment(*getAlignment()); if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); SimpleRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(target); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // ScalarizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 
// sizes and asserts that it is not already set. SmallVector emptyTileSizes; LinalgTilingPattern pattern(getContext(), tilingOptions); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr result = pattern.returningMatchAndRewrite(target, rewriter); if (failed(result)) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(result->op); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // SplitOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. ArrayRef payload = state.getPayloadOps(getTarget()); SimpleRewriter rewriter(getContext()); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); splitPoints = llvm::to_vector(llvm::map_range( state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { diag = emitSilenceableError() << "expected dynamic split point handle to point to a " "single-result index-typed op"; diag.attachNote(op->getLoc()) << "dynamic split point"; } return OpFoldResult(op->getResult(0)); })); if (!diag.succeeded()) return diag; if (splitPoints.size() != payload.size()) { emitError() << "expected the dynamic split point handle to point to as " "many operations (" << splitPoints.size() << ") as the target handle (" << payload.size() << ")"; return DiagnosedSilenceableFailure::definiteFailure(); } } else { splitPoints.resize(payload.size(), rewriter.getIndexAttr(getStaticSplitPoint())); } // Split each target operation. 
SmallVector first, second; for (const auto &pair : llvm::zip(payload, splitPoints)) { Operation *target = std::get<0>(pair); auto linalgOp = dyn_cast(target); if (!linalgOp) { auto diag = emitSilenceableError() << "only applies to structured ops"; diag.attachNote(target->getLoc()) << "target op"; return diag; } if (getDimension() >= linalgOp.getNumLoops()) { auto diag = emitSilenceableError() << "dimension " << getDimension() << " does not exist in target op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } rewriter.setInsertionPoint(linalgOp); std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); } results.set(getFirst().cast(), first); results.set(getSecond().cast(), second); return DiagnosedSilenceableFailure::success(); } void SplitOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); if (getDynamicSplitPoint()) onlyReadsHandle(getDynamicSplitPoint(), effects); producesHandle(getResults(), effects); modifiesPayload(effects); } ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; IntegerAttr staticSplitPoint; auto pdlOperationType = pdl::OperationType::get(parser.getBuilder().getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parser.parseKeyword("after")) return failure(); OptionalParseResult dynamicPointParseResult = parser.parseOptionalOperand(dynamicSplitPoint); if (!dynamicPointParseResult.hasValue()) { int64_t staticSplitPointValue; if (failed(parser.parseInteger(staticSplitPointValue))) return failure(); staticSplitPoint = parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); } else { if (failed(*dynamicPointParseResult) || parser.resolveOperand(dynamicSplitPoint, pdlOperationType, result.operands)) { return failure(); } staticSplitPoint = 
parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); } result.addAttribute( SplitOp::getStaticSplitPointAttrName(result.name).getValue(), staticSplitPoint); if (failed(parser.parseOptionalAttrDict(result.attributes))) return failure(); result.addTypes({pdlOperationType, pdlOperationType}); return success(); } void SplitOp::print(OpAsmPrinter &printer) { printer << " " << getTarget() << " after "; int64_t staticSplitSize = static_cast(getStaticSplitPoint()); if (staticSplitSize != ShapedType::kDynamicSize) printer << staticSplitSize; else printer << getDynamicSplitPoint(); printer << " "; printer.printOptionalAttrDict(getOperation()->getAttrs(), {getStaticSplitPointAttrName()}); } LogicalResult SplitOp::verify() { if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamicSize) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split point to be provided"; } return success(); } //===----------------------------------------------------------------------===// // SplitReductionOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return std::pair(getSplitFactor(), getInsertSplitDimension()); }; SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) ? 
splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(splitResult->initOrAlloc); results.push_back(splitResult->fillOp); results.push_back(splitResult->splitLinalgOp); results.push_back(splitResult->resultCombiningLinalgOp); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { LinalgTilingOptions tilingOptions; SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; dynamicSizeProducers.reserve(getDynamicSizes().size()); for (Value dynamicSizeProducerHandle : getDynamicSizes()) { dynamicSizeProducers.push_back( state.getPayloadOps(dynamicSizeProducerHandle)); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many dynamic size-producing operations (" << dynamicSizeProducers.back().size() << ") as target ops (" << targets.size() << ")"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && op->getResult(0).getType().isa()) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " "with a single index-type result"; diag.attachNote(op->getLoc()) << "size producer op"; diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; return diag; } } SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); for (auto &en : 
llvm::enumerate(targets)) { auto linalgOp = dyn_cast(en.value()); if (!linalgOp) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "only linalg ops are supported"; diag.attachNote(en.value()->getLoc()) << "target op"; return diag; } unsigned index = en.index(); if (!tileSizes.empty()) { tilingOptions.setTileSizeComputationFunction( [&, index](OpBuilder &b, Operation *) { SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( getLoc(), attr.cast().getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); } } return sizes; }); } tilingOptions.setInterchange(extractUIntArray(getInterchange())); LinalgTilingPattern pattern(getContext(), tilingOptions); SimpleRewriter rewriter(linalgOp.getContext()); FailureOr tiledOp = pattern.returningMatchAndRewrite(linalgOp, rewriter); if (failed(tiledOp)) return DiagnosedSilenceableFailure::definiteFailure(); tiled.push_back(tiledOp->op); for (const auto &en2 : llvm::enumerate(tiledOp->loops)) loops[en2.index()].push_back(en2.value()); } transformResults.set(getTiledLinalgOp().cast(), tiled); for (const auto &en : llvm::enumerate(loops)) transformResults.set(getLoops()[en.index()].cast(), en.value()); return DiagnosedSilenceableFailure::success(); } SmallVector transform::TileOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; Builder builder(getContext()); for (int64_t size : tileSizes) { if (size == ShapedType::kDynamicSize) { results.push_back(dynamic[dynamicPos++]); } else { results.push_back(builder.getIndexAttr(size)); } } return results; } ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; ArrayAttr 
staticSizes; auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || parser.parseOptionalAttrDict(result.attributes)) return ParseResult::failure(); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), getStaticSizes()); p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); } void transform::TileOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getDynamicSizes(), effects); producesHandle(getTiledLinalgOp(), effects); producesHandle(getLoops(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TileToForeachThreadOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( TilingInterface target, SmallVectorImpl &results, transform::TransformState &state) { IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); auto maybeThreadDimMappingAttr = getThreadDimMapping(); FailureOr tilingResult = linalg::tileToForeachThreadOp( rewriter, target, getAsOpFoldResult(getNumThreads()), maybeThreadDimMappingAttr ? 
extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) : ArrayRef{}); if (failed(tilingResult)) return emitDefaultSilenceableFailure(target); rewriter.replaceOp(target, tilingResult->tileOp->getResults()); results.assign({tilingResult->tileOp, tilingResult->tiledOp}); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // VectorizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::VectorizeOp::applyToOne(Operation *target, SmallVectorImpl &results, transform::TransformState &state) { if (!target->hasTrait()) { auto diag = this->emitOpError("requires isolated-from-above targets"); diag.attachNote(target->getLoc()) << "non-isolated target"; return DiagnosedSilenceableFailure::definiteFailure(); } MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); vector::populateVectorReductionToContractPatterns(patterns); patterns.add(ctx, /*benefit=*/2); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); results.push_back(target); return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { /// Registers new ops and declares PDL as dependent dialect since the additional /// ops are using PDL types for operands and results. 
class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { public: LinalgTransformDialectExtension() { declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" >(); } }; } // namespace #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" void mlir::linalg::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }