//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::linalg; using namespace mlir::transform; /// Extracts a vector of int64_t from an array attribute. Asserts if the /// attribute contains values other than integers. static SmallVector extractI64Array(ArrayAttr attr) { SmallVector result; result.reserve(attr.size()); for (APInt value : attr.getAsValueRange()) result.push_back(value.getSExtValue()); return result; } /// Extracts a vector of unsigned from an array attribute. Asserts if the /// attribute contains values other than intergers. May truncate. static SmallVector extractUIntArray(ArrayAttr attr) { SmallVector result; result.reserve(attr.size()); for (APInt value : attr.getAsValueRange()) result.push_back(value.getZExtValue()); return result; } namespace { /// A simple pattern rewriter that implements no special logic. class SimpleRewriter : public PatternRewriter { public: SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} }; } // namespace //===----------------------------------------------------------------------===// // InterchangeOp //===----------------------------------------------------------------------===// FailureOr transform::InterchangeOp::applyToOne(LinalgOp target) { SmallVector interchangeVector = extractUIntArray(getIteratorInterchange()); // Exit early if no transformation is needed. if (interchangeVector.empty()) return target; auto genericTarget = dyn_cast(target.getOperation()); if (!genericTarget) { InFlightDiagnostic diag = emitOpError() << "applies to " << GenericOp::getOperationName() << " ops"; diag.attachNote(target.getLoc()) << "attempted to apply to this op"; return diag; } GenericOpInterchangePattern pattern(getContext(), interchangeVector); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr result = pattern.returningMatchAndRewrite(genericTarget, rewriter); if (failed(result)) return failure(); return cast(result->getOperation()); } LogicalResult transform::InterchangeOp::verify() { SmallVector permutation = extractUIntArray(getIteratorInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() << "expects iterator_interchange to be a permutation, found " << getIteratorInterchange(); } return success(); } //===---------------------------------------------------------------------===// // PadOp //===---------------------------------------------------------------------===// FailureOr transform::PadOp::applyToOne(LinalgOp target) { // Convert the integer packing flags to booleans. SmallVector packPaddings; for (int64_t packPadding : extractI64Array(getPackPaddings())) packPaddings.push_back(static_cast(packPadding)); // Convert the padding values to attributes. SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { Attribute attr = std::get<0>(it); Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = attr.dyn_cast()) { paddingValues.push_back( parseAttribute(attr.cast(), elementType)); if (!paddingValues.back()) { InFlightDiagnostic diag = emitOpError() << "expects a padding value that parses to " << elementType << ", got " << std::get<0>(it); diag.attachNote(target.getLoc()) << "when applied to this op"; return diag; } continue; } // Otherwise, add the attribute directly. if (attr.getType() != elementType) { InFlightDiagnostic diag = emitOpError() << "expects a padding value of type " << elementType << ", got " << attr; diag.attachNote(target.getLoc()) << "when applied to this op"; return diag; } paddingValues.push_back(attr); } // Extract the transpose vectors. SmallVector> transposePaddings; for (Attribute transposeVector : getTransposePaddings().cast()) transposePaddings.push_back( extractI64Array(transposeVector.cast())); LinalgPaddingOptions paddingOptions; paddingOptions.setPaddingValues(paddingValues); paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); paddingOptions.setPackPaddings(packPaddings); paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); paddingOptions.setTransposePaddings(transposePaddings); LinalgPaddingPattern pattern(getContext(), paddingOptions); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = pattern.returningMatchAndRewrite(target, rewriter); if (failed(patternResult)) { InFlightDiagnostic diag = emitError() << "failed to apply pattern to target op"; diag.attachNote(target.getLoc()) << "target op"; return diag; } return patternResult; } LogicalResult transform::PadOp::verify() { SmallVector packPaddings = extractI64Array(getPackPaddings()); if (any_of(packPaddings, [](int64_t packPadding) { return packPadding != 0 && packPadding != 1; })) { return emitOpError() << "expects pack_paddings to contain booleans (0/1), found " << getPackPaddings(); } SmallVector paddingDimensions = extractI64Array(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { return emitOpError() << "expects padding_dimensions to contain positive integers, found " << getPaddingDimensions(); } SmallVector hoistPaddings = extractI64Array(getHoistPaddings()); if (any_of(hoistPaddings, [](int64_t hoistPadding) { return hoistPadding < 0; })) { return emitOpError() << "expects hoist_paddings to contain positive integers, found " << getHoistPaddings(); } ArrayAttr transposes = getTransposePaddings(); for (Attribute attr : transposes) { SmallVector transpose = extractFromI64ArrayAttr(attr); auto sequence = llvm::to_vector(llvm::seq(0, transpose.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), transpose.end())) { return emitOpError() << "expects transpose_paddings to be a permutation, found " << attr; } } return success(); } //===----------------------------------------------------------------------===// // ScalarizeOp //===----------------------------------------------------------------------===// FailureOr transform::ScalarizeOp::applyToOne(LinalgOp target) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile // sizes and asserts that it is not already set. SmallVector emptyTileSizes; LinalgTilingPattern pattern(getContext(), tilingOptions); SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr result = pattern.returningMatchAndRewrite(target, rewriter); if (failed(result)) return failure(); return result->op; } //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. static LogicalResult applyTilingToAll(Operation *transformOp, Value target, ArrayRef tileSizes, transform::TransformResults &transformResults, transform::TransformState &state, function_ref(LinalgOp)> applyFn) { // Number of loops: Number of tiles sizes that are not zero. size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); // All payload ops. These should all be LinalgOps for now. ArrayRef payloadOps = state.getPayloadOps(target); SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].reserve(payloadOps.size()); for (Operation *target : payloadOps) { auto linalgOp = dyn_cast(target); if (!linalgOp) return transformOp->emitError("only LinalgOps are supported"); FailureOr tiled = applyFn(linalgOp); if (failed(tiled)) return failure(); tiledLinalgOps.push_back(tiled->op); if (tiled->loops.size() != numLoops) // Not enough loops were generated. This usually means that the input size // was smaller than the tiling size. // TODO: LinalgTilingPattern should return failure(). return failure(); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].push_back(tiled->loops[i]); } transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); for (unsigned int i = 0; i < numLoops; ++i) transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); return success(); } LogicalResult transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { LinalgTilingOptions tilingOptions; SmallVector tileSizes = extractI64Array(getSizes()); if (!tileSizes.empty()) tilingOptions.setTileSizes(tileSizes); tilingOptions.setInterchange(extractUIntArray(getInterchange())); LinalgTilingPattern pattern(getContext(), tilingOptions); return applyTilingToAll(getOperation(), getTarget(), tileSizes, transformResults, state, [&](LinalgOp linalgOp) { SimpleRewriter rewriter(linalgOp.getContext()); return pattern.returningMatchAndRewrite(linalgOp, rewriter); }); } ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue(); OpAsmParser::UnresolvedOperand targetOperand; SMLoc opLoc = parser.getCurrentLocation(); if (parser.parseOperand(targetOperand) || parser.parseOptionalAttrDict(result.attributes)) return failure(); Attribute sizesAttr = result.attributes.get(sizesAttrName); if (!sizesAttr) return parser.emitError(opLoc) << "expected '" << sizesAttrName << "' attribute"; auto sizesArrayAttr = sizesAttr.dyn_cast(); if (!sizesArrayAttr) return parser.emitError(opLoc) << "'" << sizesAttrName << "' attribute must be an array"; Type pdlOpType = parser.getBuilder().getType(); size_t numExpectedLoops = sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOpType)); if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) return failure(); return success(); } void TileOp::print(OpAsmPrinter &p) { p << ' '; p << getTarget(); p.printOptionalAttrDict((*this)->getAttrs()); } void TileOp::getEffects( SmallVectorImpl> &effects) { // `target` arg is consumed and can no longer be used. effects.emplace_back(MemoryEffects::Read::get(), getTarget(), TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Free::get(), getTarget(), TransformMappingResource::get()); for (Value r : getResults()) { effects.emplace_back(MemoryEffects::Write::get(), r, TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Allocate::get(), r, TransformMappingResource::get()); } effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); } //===----------------------------------------------------------------------===// // VectorizeOp //===----------------------------------------------------------------------===// FailureOr VectorizeOp::applyToOne(Operation *target) { if (!target->hasTrait()) { InFlightDiagnostic diag = emitOpError() << "applies only to isolated-from-above targets"; diag.attachNote(target->getLoc()) << "non-isolated target"; return diag; } MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); vector::populateVectorReductionToContractPatterns(patterns); patterns.add(ctx, /*benefit=*/2); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) { InFlightDiagnostic diag = emitError() << "failed to apply"; diag.attachNote(target->getLoc()) << "target op"; return diag; } return target; } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { /// Registers new ops and declares PDL as dependent dialect since the additional /// ops are using PDL types for operands and results. class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { public: LinalgTransformDialectExtension() { declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" >(); } }; } // namespace #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" void mlir::linalg::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }