1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Parser/Parser.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 using namespace mlir::transform; 22 23 /// Extracts a vector of int64_t from an array attribute. Asserts if the 24 /// attribute contains values other than integers. 25 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 26 SmallVector<int64_t> result; 27 result.reserve(attr.size()); 28 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 29 result.push_back(value.getSExtValue()); 30 return result; 31 } 32 33 /// Extracts a vector of unsigned from an array attribute. Asserts if the 34 /// attribute contains values other than intergers. May truncate. 35 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 36 SmallVector<unsigned> result; 37 result.reserve(attr.size()); 38 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 39 result.push_back(value.getZExtValue()); 40 return result; 41 } 42 43 namespace { 44 /// A simple pattern rewriter that implements no special logic. 45 class SimpleRewriter : public PatternRewriter { 46 public: 47 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 48 }; 49 } // namespace 50 51 /// Attempts to apply the pattern specified as template argument to the given 52 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 53 /// function that returns the "main" result or failure. Returns failure if the 54 /// pattern failed to apply. Extra arguments are forwarded to the pattern 55 /// constructor. 56 template <typename PatternTy, typename... Args> 57 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 58 // Check if the given operation has the type expected by the pattern. 59 using OpTy = typename llvm::function_traits< 60 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 61 auto op = dyn_cast<OpTy>(operation); 62 if (!op) 63 return failure(); 64 65 // Apply the pattern directly to the op. 66 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 67 SimpleRewriter rewriter(operation->getContext()); 68 rewriter.setInsertionPoint(operation); 69 auto result = pattern.returningMatchAndRewrite(op, rewriter); 70 if (failed(result)) 71 return failure(); 72 return cast<LinalgOp>(result->getOperation()); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // DecomposeOp 77 //===----------------------------------------------------------------------===// 78 79 FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target, 80 TransformState &state) { 81 FailureOr<LinalgOp> windowed = 82 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 83 if (succeeded(windowed)) 84 return windowed; 85 86 FailureOr<LinalgOp> depthwise = 87 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 88 if (succeeded(depthwise)) 89 return depthwise; 90 91 return reportUnknownTransformError(target); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // FuseOp 96 //===----------------------------------------------------------------------===// 97 98 /// Apply a tiling transformation to all payload ops and store both the 99 /// tiled operation as well as the created tile loops. 100 static LogicalResult 101 applyTilingToAll(Operation *transformOp, Value target, 102 ArrayRef<int64_t> tileSizes, 103 transform::TransformResults &transformResults, 104 transform::TransformState &state, 105 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 106 // Number of loops: Number of tiles sizes that are not zero. 107 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 108 // All payload ops. These should all be LinalgOps for now. 109 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 110 111 SmallVector<Operation *> tiledLinalgOps; 112 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 113 for (unsigned int i = 0; i < numLoops; ++i) 114 loopOps[i].reserve(payloadOps.size()); 115 116 for (Operation *target : payloadOps) { 117 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 118 if (!linalgOp) 119 return transformOp->emitError("only LinalgOps are supported"); 120 121 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 122 if (failed(tiled)) 123 return failure(); 124 125 tiledLinalgOps.push_back(tiled->op); 126 if (tiled->loops.size() != numLoops) 127 // Not enough loops were generated. This usually means that the input size 128 // was smaller than the tiling size. 129 // TODO: LinalgTilingPattern should return failure(). 130 return failure(); 131 for (unsigned int i = 0; i < numLoops; ++i) 132 loopOps[i].push_back(tiled->loops[i]); 133 } 134 135 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 136 for (unsigned int i = 0; i < numLoops; ++i) 137 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 138 return success(); 139 } 140 141 /// Parse a tiling-like operation that returns the tiled op as well as the 142 /// created tile loops. The function counts the non-zero tile sizes to compute 143 /// the number of results. 144 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 145 StringRef sizesAttrName) { 146 OpAsmParser::UnresolvedOperand targetOperand; 147 SMLoc opLoc = parser.getCurrentLocation(); 148 if (parser.parseOperand(targetOperand) || 149 parser.parseOptionalAttrDict(result.attributes)) 150 return failure(); 151 Attribute sizesAttr = result.attributes.get(sizesAttrName); 152 if (!sizesAttr) 153 return parser.emitError(opLoc) 154 << "expected '" << sizesAttrName << "' attribute"; 155 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 156 if (!sizesArrayAttr) 157 return parser.emitError(opLoc) 158 << "'" << sizesAttrName << "' attribute must be an array"; 159 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 160 size_t numExpectedLoops = 161 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 162 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 163 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 164 return failure(); 165 return success(); 166 } 167 168 DiagnosedSilenceableFailure 169 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 170 mlir::transform::TransformState &state) { 171 LinalgTilingAndFusionOptions fusionOptions; 172 fusionOptions.tileSizes = extractI64Array(getTileSizes()); 173 fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); 174 175 LogicalResult result = applyTilingToAll( 176 getOperation(), getTarget(), fusionOptions.tileSizes, transformResults, 177 state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 178 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 179 SimpleRewriter rewriter(getContext()); 180 rewriter.setInsertionPoint(linalgOp); 181 FailureOr<TileLoopNest> tileLoopNest = 182 pattern.returningMatchAndRewrite(linalgOp, rewriter); 183 if (failed(tileLoopNest)) 184 return failure(); 185 186 TiledLinalgOp tiledLinalgOp; 187 tiledLinalgOp.op = tileLoopNest->getRootOp(); 188 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 189 tileLoopNest->getLoopOps().end()}; 190 return tiledLinalgOp; 191 }); 192 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 193 : DiagnosedSilenceableFailure::success(); 194 } 195 196 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 197 OperationState &result) { 198 return parseTileLikeOp( 199 parser, result, 200 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 201 } 202 203 void transform::FuseOp::print(OpAsmPrinter &p) { 204 p << ' '; 205 p << getTarget(); 206 p.printOptionalAttrDict((*this)->getAttrs()); 207 } 208 209 LogicalResult transform::FuseOp::verify() { 210 SmallVector<int64_t> permutation = extractI64Array(getTileInterchange()); 211 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 212 if (!std::is_permutation(sequence.begin(), sequence.end(), 213 permutation.begin(), permutation.end())) { 214 return emitOpError() << "expects interchange to be a permutation, found " 215 << getTileInterchange(); 216 } 217 return success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // GeneralizeOp 222 //===----------------------------------------------------------------------===// 223 224 FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target, 225 TransformState &state) { 226 // Exit early if no transformation is needed. 227 if (isa<GenericOp>(target)) 228 return target; 229 230 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 231 if (succeeded(generic)) 232 return generic; 233 234 return reportUnknownTransformError(target); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // InterchangeOp 239 //===----------------------------------------------------------------------===// 240 241 FailureOr<LinalgOp> 242 transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) { 243 SmallVector<unsigned> interchangeVector = 244 extractUIntArray(getIteratorInterchange()); 245 // Exit early if no transformation is needed. 246 if (interchangeVector.empty()) 247 return target; 248 249 auto genericTarget = dyn_cast<GenericOp>(target.getOperation()); 250 if (!genericTarget) { 251 InFlightDiagnostic diag = emitOpError() 252 << "applies to " << GenericOp::getOperationName() 253 << " ops"; 254 diag.attachNote(target.getLoc()) << "attempted to apply to this op"; 255 return diag; 256 } 257 258 return tryApply<GenericOpInterchangePattern>(target, interchangeVector); 259 } 260 261 LogicalResult transform::InterchangeOp::verify() { 262 SmallVector<unsigned> permutation = 263 extractUIntArray(getIteratorInterchange()); 264 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 265 if (!std::is_permutation(sequence.begin(), sequence.end(), 266 permutation.begin(), permutation.end())) { 267 return emitOpError() 268 << "expects iterator_interchange to be a permutation, found " 269 << getIteratorInterchange(); 270 } 271 return success(); 272 } 273 274 //===---------------------------------------------------------------------===// 275 // PadOp 276 //===---------------------------------------------------------------------===// 277 278 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target, 279 TransformState &state) { 280 // Convert the integer packing flags to booleans. 281 SmallVector<bool> packPaddings; 282 for (int64_t packPadding : extractI64Array(getPackPaddings())) 283 packPaddings.push_back(static_cast<bool>(packPadding)); 284 285 // Convert the padding values to attributes. 286 SmallVector<Attribute> paddingValues; 287 for (auto const &it : 288 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 289 Attribute attr = std::get<0>(it); 290 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 291 // Try to parse string attributes to obtain an attribute of element type. 292 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 293 paddingValues.push_back( 294 parseAttribute(attr.cast<StringAttr>(), elementType)); 295 if (!paddingValues.back()) { 296 InFlightDiagnostic diag = emitOpError() 297 << "expects a padding value that parses to " 298 << elementType << ", got " << std::get<0>(it); 299 diag.attachNote(target.getLoc()) << "when applied to this op"; 300 return diag; 301 } 302 continue; 303 } 304 // Otherwise, add the attribute directly. 305 if (attr.getType() != elementType) { 306 InFlightDiagnostic diag = emitOpError() 307 << "expects a padding value of type " 308 << elementType << ", got " << attr; 309 diag.attachNote(target.getLoc()) << "when applied to this op"; 310 return diag; 311 } 312 paddingValues.push_back(attr); 313 } 314 315 // Extract the transpose vectors. 316 SmallVector<SmallVector<int64_t>> transposePaddings; 317 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 318 transposePaddings.push_back( 319 extractI64Array(transposeVector.cast<ArrayAttr>())); 320 321 LinalgPaddingOptions paddingOptions; 322 paddingOptions.setPaddingValues(paddingValues); 323 paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); 324 paddingOptions.setPackPaddings(packPaddings); 325 paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); 326 paddingOptions.setTransposePaddings(transposePaddings); 327 328 FailureOr<LinalgOp> result = 329 tryApply<LinalgPaddingPattern>(target, paddingOptions); 330 if (succeeded(result)) 331 return result; 332 333 InFlightDiagnostic diag = emitError() 334 << "failed to apply pattern to target op"; 335 diag.attachNote(target.getLoc()) << "target op"; 336 return diag; 337 } 338 339 LogicalResult transform::PadOp::verify() { 340 SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings()); 341 if (any_of(packPaddings, [](int64_t packPadding) { 342 return packPadding != 0 && packPadding != 1; 343 })) { 344 return emitOpError() 345 << "expects pack_paddings to contain booleans (0/1), found " 346 << getPackPaddings(); 347 } 348 349 SmallVector<int64_t> paddingDimensions = 350 extractI64Array(getPaddingDimensions()); 351 if (any_of(paddingDimensions, 352 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 353 return emitOpError() 354 << "expects padding_dimensions to contain positive integers, found " 355 << getPaddingDimensions(); 356 } 357 358 SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings()); 359 if (any_of(hoistPaddings, 360 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 361 return emitOpError() 362 << "expects hoist_paddings to contain positive integers, found " 363 << getHoistPaddings(); 364 } 365 366 ArrayAttr transposes = getTransposePaddings(); 367 for (Attribute attr : transposes) { 368 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 369 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 370 if (!std::is_permutation(sequence.begin(), sequence.end(), 371 transpose.begin(), transpose.end())) { 372 return emitOpError() 373 << "expects transpose_paddings to be a permutation, found " 374 << attr; 375 } 376 } 377 return success(); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // ScalarizeOp 382 //===----------------------------------------------------------------------===// 383 384 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target, 385 TransformState &state) { 386 LinalgTilingOptions tilingOptions; 387 tilingOptions.scalarizeDynamicDims(); 388 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 389 // sizes and asserts that it is not already set. 390 SmallVector<int64_t> emptyTileSizes; 391 LinalgTilingPattern pattern(getContext(), tilingOptions); 392 SimpleRewriter rewriter(getContext()); 393 rewriter.setInsertionPoint(target); 394 FailureOr<TiledLinalgOp> result = 395 pattern.returningMatchAndRewrite(target, rewriter); 396 if (failed(result)) 397 return failure(); 398 399 return result->op; 400 } 401 402 //===----------------------------------------------------------------------===// 403 // SplitOp 404 //===----------------------------------------------------------------------===// 405 406 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, 407 TransformState &state) { 408 // Collect the dynamic split points if provided. 409 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); 410 SimpleRewriter rewriter(getContext()); 411 SmallVector<OpFoldResult> splitPoints; 412 splitPoints.reserve(payload.size()); 413 if (getDynamicSplitPoint()) { 414 auto diag = DiagnosedSilenceableFailure::success(); 415 splitPoints = llvm::to_vector(llvm::map_range( 416 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { 417 if (op->getNumResults() != 1 || 418 !op->getResult(0).getType().isIndex()) { 419 diag = emitSilenceableError() 420 << "expected dynamic split point handle to point to a " 421 "single-result index-typed op"; 422 diag.attachNote(op->getLoc()) << "dynamic split point"; 423 } 424 return OpFoldResult(op->getResult(0)); 425 })); 426 if (!diag.succeeded()) 427 return diag; 428 429 if (splitPoints.size() != payload.size()) { 430 emitError() << "expected the dynamic split point handle to point to as " 431 "many operations (" 432 << splitPoints.size() << ") as the target handle (" 433 << payload.size() << ")"; 434 return DiagnosedSilenceableFailure::definiteFailure(); 435 } 436 } else { 437 splitPoints.resize(payload.size(), 438 rewriter.getIndexAttr(getStaticSplitPoint())); 439 } 440 441 // Split each target operation. 442 SmallVector<Operation *> first, second; 443 for (const auto &pair : llvm::zip(payload, splitPoints)) { 444 Operation *target = std::get<0>(pair); 445 auto linalgOp = dyn_cast<LinalgOp>(target); 446 if (!linalgOp) { 447 auto diag = emitSilenceableError() << "only applies to structured ops"; 448 diag.attachNote(target->getLoc()) << "target op"; 449 return diag; 450 } 451 452 if (getDimension() >= linalgOp.getNumLoops()) { 453 auto diag = emitSilenceableError() << "dimension " << getDimension() 454 << " does not exist in target op"; 455 diag.attachNote(target->getLoc()) << "target op"; 456 return diag; 457 } 458 459 rewriter.setInsertionPoint(linalgOp); 460 std::tie(first.emplace_back(), second.emplace_back()) = 461 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); 462 } 463 464 results.set(getFirst().cast<OpResult>(), first); 465 results.set(getSecond().cast<OpResult>(), second); 466 return DiagnosedSilenceableFailure::success(); 467 } 468 469 void SplitOp::getEffects( 470 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 471 // The target handle is consumed. 472 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 473 TransformMappingResource::get()); 474 effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 475 TransformMappingResource::get()); 476 477 // The dynamic split point handle is not consumed. 478 if (getDynamicSplitPoint()) { 479 effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(), 480 TransformMappingResource::get()); 481 } 482 483 // The resulting handles are produced. 484 for (Value result : getResults()) { 485 effects.emplace_back(MemoryEffects::Allocate::get(), result, 486 TransformMappingResource::get()); 487 effects.emplace_back(MemoryEffects::Write::get(), result, 488 TransformMappingResource::get()); 489 } 490 491 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 492 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 493 } 494 495 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 496 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; 497 IntegerAttr staticSplitPoint; 498 auto pdlOperationType = 499 pdl::OperationType::get(parser.getBuilder().getContext()); 500 if (parser.parseOperand(target) || 501 parser.resolveOperand(target, pdlOperationType, result.operands) || 502 parser.parseKeyword("after")) 503 return failure(); 504 505 OptionalParseResult dynamicPointParseResult = 506 parser.parseOptionalOperand(dynamicSplitPoint); 507 if (!dynamicPointParseResult.hasValue()) { 508 int64_t staticSplitPointValue; 509 if (failed(parser.parseInteger(staticSplitPointValue))) 510 return failure(); 511 512 staticSplitPoint = 513 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); 514 } else { 515 if (failed(*dynamicPointParseResult) || 516 parser.resolveOperand(dynamicSplitPoint, pdlOperationType, 517 result.operands)) { 518 return failure(); 519 } 520 521 staticSplitPoint = 522 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); 523 } 524 525 result.addAttribute( 526 SplitOp::getStaticSplitPointAttrName(result.name).getValue(), 527 staticSplitPoint); 528 if (failed(parser.parseOptionalAttrDict(result.attributes))) 529 return failure(); 530 531 result.addTypes({pdlOperationType, pdlOperationType}); 532 return success(); 533 } 534 535 void SplitOp::print(OpAsmPrinter &printer) { 536 printer << " " << getTarget() << " after "; 537 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); 538 if (staticSplitSize != ShapedType::kDynamicSize) 539 printer << staticSplitSize; 540 else 541 printer << getDynamicSplitPoint(); 542 printer << " "; 543 printer.printOptionalAttrDict(getOperation()->getAttrs(), 544 {getStaticSplitPointAttrName()}); 545 } 546 547 LogicalResult SplitOp::verify() { 548 if ((static_cast<int64_t>(getStaticSplitPoint()) != 549 ShapedType::kDynamicSize) ^ 550 (getDynamicSplitPoint() == nullptr)) { 551 return emitOpError() 552 << "expects either a dynamic or a static split point to be provided"; 553 } 554 return success(); 555 } 556 557 //===----------------------------------------------------------------------===// 558 // SplitReductionOp 559 //===----------------------------------------------------------------------===// 560 561 FailureOr<SmallVector<Operation *>> 562 transform::SplitReductionOp::applyToOne(LinalgOp target, 563 TransformState &state) { 564 ControlSplitReductionFn splitFn = [&](LinalgOp) { 565 return std::pair<int64_t, unsigned>(getSplitFactor(), 566 getInsertSplitDimension()); 567 }; 568 SimpleRewriter rewriter(getContext()); 569 rewriter.setInsertionPoint(target); 570 FailureOr<SplitReductionResult> splitResult = 571 (getUseScalingAlgorithm()) 572 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 573 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 574 if (failed(splitResult)) 575 return getOperation()->emitError("failed to apply"); 576 return SmallVector<Operation *>{splitResult->fillOp, 577 splitResult->splitLinalgOp, 578 splitResult->resultCombiningLinalgOp}; 579 } 580 581 //===----------------------------------------------------------------------===// 582 // TileOp 583 //===----------------------------------------------------------------------===// 584 585 DiagnosedSilenceableFailure 586 transform::TileOp::apply(TransformResults &transformResults, 587 TransformState &state) { 588 LinalgTilingOptions tilingOptions; 589 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 590 591 if (!tileSizes.empty()) 592 tilingOptions.setTileSizes(tileSizes); 593 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 594 LinalgTilingPattern pattern(getContext(), tilingOptions); 595 596 LogicalResult result = applyTilingToAll( 597 getOperation(), getTarget(), tileSizes, transformResults, state, 598 [&](LinalgOp linalgOp) { 599 SimpleRewriter rewriter(linalgOp.getContext()); 600 return pattern.returningMatchAndRewrite(linalgOp, rewriter); 601 }); 602 return DiagnosedSilenceableFailure(result); 603 } 604 605 ParseResult transform::TileOp::parse(OpAsmParser &parser, 606 OperationState &result) { 607 return parseTileLikeOp(parser, result, 608 TileOp::getSizesAttrName(result.name).getValue()); 609 } 610 611 void TileOp::print(OpAsmPrinter &p) { 612 p << ' '; 613 p << getTarget(); 614 p.printOptionalAttrDict((*this)->getAttrs()); 615 } 616 617 //===----------------------------------------------------------------------===// 618 // VectorizeOp 619 //===----------------------------------------------------------------------===// 620 621 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target, 622 TransformState &state) { 623 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 624 InFlightDiagnostic diag = emitOpError() 625 << "applies only to isolated-from-above targets"; 626 diag.attachNote(target->getLoc()) << "non-isolated target"; 627 return diag; 628 } 629 630 MLIRContext *ctx = getContext(); 631 RewritePatternSet patterns(ctx); 632 patterns.add<LinalgVectorizationPattern>(ctx); 633 634 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 635 vector::populateVectorReductionToContractPatterns(patterns); 636 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 637 linalg::LinalgCopyVTWForwardingPattern>(ctx, 638 /*benefit=*/2); 639 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 640 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 641 if (getVectorizePadding()) 642 linalg::populatePadOpVectorizationPatterns(patterns); 643 644 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 645 return reportUnknownTransformError(target); 646 return target; 647 } 648 649 //===----------------------------------------------------------------------===// 650 // Transform op registration 651 //===----------------------------------------------------------------------===// 652 653 namespace { 654 /// Registers new ops and declares PDL as dependent dialect since the additional 655 /// ops are using PDL types for operands and results. 656 class LinalgTransformDialectExtension 657 : public transform::TransformDialectExtension< 658 LinalgTransformDialectExtension> { 659 public: 660 LinalgTransformDialectExtension() { 661 declareDependentDialect<pdl::PDLDialect>(); 662 declareDependentDialect<scf::SCFDialect>(); 663 declareDependentDialect<vector::VectorDialect>(); 664 registerTransformOps< 665 #define GET_OP_LIST 666 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 667 >(); 668 } 669 }; 670 } // namespace 671 672 #define GET_OP_CLASSES 673 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 674 675 void mlir::linalg::registerTransformDialectExtension( 676 DialectRegistry ®istry) { 677 registry.addExtensions<LinalgTransformDialectExtension>(); 678 } 679