1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Parser/Parser.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 using namespace mlir::transform; 22 23 /// Extracts a vector of int64_t from an array attribute. Asserts if the 24 /// attribute contains values other than integers. 25 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 26 SmallVector<int64_t> result; 27 result.reserve(attr.size()); 28 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 29 result.push_back(value.getSExtValue()); 30 return result; 31 } 32 33 /// Extracts a vector of unsigned from an array attribute. Asserts if the 34 /// attribute contains values other than intergers. May truncate. 35 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 36 SmallVector<unsigned> result; 37 result.reserve(attr.size()); 38 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 39 result.push_back(value.getZExtValue()); 40 return result; 41 } 42 43 namespace { 44 /// A simple pattern rewriter that implements no special logic. 45 class SimpleRewriter : public PatternRewriter { 46 public: 47 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 48 }; 49 } // namespace 50 51 /// Attempts to apply the pattern specified as template argument to the given 52 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 53 /// function that returns the "main" result or failure. Returns failure if the 54 /// pattern failed to apply. Extra arguments are forwarded to the pattern 55 /// constructor. 56 template <typename PatternTy, typename... Args> 57 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 58 // Check if the given operation has the type expected by the pattern. 59 using OpTy = typename llvm::function_traits< 60 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 61 auto op = dyn_cast<OpTy>(operation); 62 if (!op) 63 return failure(); 64 65 // Apply the pattern directly to the op. 66 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 67 SimpleRewriter rewriter(operation->getContext()); 68 rewriter.setInsertionPoint(operation); 69 auto result = pattern.returningMatchAndRewrite(op, rewriter); 70 if (failed(result)) 71 return failure(); 72 return cast<LinalgOp>(result->getOperation()); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // DecomposeOp 77 //===----------------------------------------------------------------------===// 78 79 DiagnosedSilenceableFailure 80 transform::DecomposeOp::applyToOne(linalg::LinalgOp target, 81 SmallVectorImpl<Operation *> &results, 82 transform::TransformState &state) { 83 FailureOr<LinalgOp> windowed = 84 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 85 if (succeeded(windowed)) { 86 results.push_back(*windowed); 87 return DiagnosedSilenceableFailure(success()); 88 } 89 FailureOr<LinalgOp> depthwise = 90 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 91 if (succeeded(depthwise)) { 92 results.push_back(*depthwise); 93 return DiagnosedSilenceableFailure(success()); 94 } 95 results.assign(1, nullptr); 96 return emitDefaultSilenceableFailure(target); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // FuseOp 101 //===----------------------------------------------------------------------===// 102 103 /// Apply a tiling transformation to all payload ops and store both the 104 /// tiled operation as well as the created tile loops. 105 static LogicalResult 106 applyTilingToAll(Operation *transformOp, Value target, 107 ArrayRef<int64_t> tileSizes, 108 transform::TransformResults &transformResults, 109 transform::TransformState &state, 110 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 111 // Number of loops: Number of tiles sizes that are not zero. 112 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 113 // All payload ops. These should all be LinalgOps for now. 114 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 115 116 SmallVector<Operation *> tiledLinalgOps; 117 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 118 for (unsigned int i = 0; i < numLoops; ++i) 119 loopOps[i].reserve(payloadOps.size()); 120 121 for (Operation *target : payloadOps) { 122 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 123 if (!linalgOp) 124 return transformOp->emitError("only LinalgOps are supported"); 125 126 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 127 if (failed(tiled)) 128 return failure(); 129 130 tiledLinalgOps.push_back(tiled->op); 131 if (tiled->loops.size() != numLoops) 132 // Not enough loops were generated. This usually means that the input size 133 // was smaller than the tiling size. 134 // TODO: LinalgTilingPattern should return failure(). 135 return failure(); 136 for (unsigned int i = 0; i < numLoops; ++i) 137 loopOps[i].push_back(tiled->loops[i]); 138 } 139 140 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 141 for (unsigned int i = 0; i < numLoops; ++i) 142 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 143 return success(); 144 } 145 146 /// Parse a tiling-like operation that returns the tiled op as well as the 147 /// created tile loops. The function counts the non-zero tile sizes to compute 148 /// the number of results. 149 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 150 StringRef sizesAttrName) { 151 OpAsmParser::UnresolvedOperand targetOperand; 152 SMLoc opLoc = parser.getCurrentLocation(); 153 if (parser.parseOperand(targetOperand) || 154 parser.parseOptionalAttrDict(result.attributes)) 155 return failure(); 156 Attribute sizesAttr = result.attributes.get(sizesAttrName); 157 if (!sizesAttr) 158 return parser.emitError(opLoc) 159 << "expected '" << sizesAttrName << "' attribute"; 160 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 161 if (!sizesArrayAttr) 162 return parser.emitError(opLoc) 163 << "'" << sizesAttrName << "' attribute must be an array"; 164 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 165 size_t numExpectedLoops = 166 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 167 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 168 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 169 return failure(); 170 return success(); 171 } 172 173 DiagnosedSilenceableFailure 174 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 175 mlir::transform::TransformState &state) { 176 LinalgTilingAndFusionOptions fusionOptions; 177 fusionOptions.tileSizes = extractI64Array(getTileSizes()); 178 fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); 179 180 LogicalResult result = applyTilingToAll( 181 getOperation(), getTarget(), fusionOptions.tileSizes, transformResults, 182 state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 183 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 184 SimpleRewriter rewriter(getContext()); 185 rewriter.setInsertionPoint(linalgOp); 186 FailureOr<TileLoopNest> tileLoopNest = 187 pattern.returningMatchAndRewrite(linalgOp, rewriter); 188 if (failed(tileLoopNest)) 189 return failure(); 190 191 TiledLinalgOp tiledLinalgOp; 192 tiledLinalgOp.op = tileLoopNest->getRootOp(); 193 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 194 tileLoopNest->getLoopOps().end()}; 195 return tiledLinalgOp; 196 }); 197 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 198 : DiagnosedSilenceableFailure::success(); 199 } 200 201 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 202 OperationState &result) { 203 return parseTileLikeOp( 204 parser, result, 205 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 206 } 207 208 void transform::FuseOp::print(OpAsmPrinter &p) { 209 p << ' '; 210 p << getTarget(); 211 p.printOptionalAttrDict((*this)->getAttrs()); 212 } 213 214 LogicalResult transform::FuseOp::verify() { 215 SmallVector<int64_t> permutation = extractI64Array(getTileInterchange()); 216 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 217 if (!std::is_permutation(sequence.begin(), sequence.end(), 218 permutation.begin(), permutation.end())) { 219 return emitOpError() << "expects interchange to be a permutation, found " 220 << getTileInterchange(); 221 } 222 return success(); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // GeneralizeOp 227 //===----------------------------------------------------------------------===// 228 229 DiagnosedSilenceableFailure 230 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, 231 SmallVectorImpl<Operation *> &results, 232 transform::TransformState &state) { 233 // Exit early if no transformation is needed. 234 if (isa<GenericOp>(target)) { 235 results.push_back(target); 236 return DiagnosedSilenceableFailure(success()); 237 } 238 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 239 if (succeeded(generic)) { 240 results.push_back(generic->getOperation()); 241 return DiagnosedSilenceableFailure(success()); 242 } 243 results.assign(1, nullptr); 244 return emitDefaultSilenceableFailure(target); 245 } 246 247 //===----------------------------------------------------------------------===// 248 // InterchangeOp 249 //===----------------------------------------------------------------------===// 250 251 DiagnosedSilenceableFailure 252 transform::InterchangeOp::applyToOne(linalg::GenericOp target, 253 SmallVectorImpl<Operation *> &results, 254 transform::TransformState &state) { 255 SmallVector<unsigned> interchangeVector = 256 extractUIntArray(getIteratorInterchange()); 257 // Exit early if no transformation is needed. 258 if (interchangeVector.empty()) { 259 results.push_back(target); 260 return DiagnosedSilenceableFailure(success()); 261 } 262 SimpleRewriter rewriter(target->getContext()); 263 FailureOr<GenericOp> res = 264 interchangeGenericOp(rewriter, target, interchangeVector); 265 if (failed(res)) 266 return DiagnosedSilenceableFailure::definiteFailure(); 267 results.push_back(res->getOperation()); 268 return DiagnosedSilenceableFailure(success()); 269 } 270 271 LogicalResult transform::InterchangeOp::verify() { 272 SmallVector<unsigned> permutation = 273 extractUIntArray(getIteratorInterchange()); 274 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 275 if (!std::is_permutation(sequence.begin(), sequence.end(), 276 permutation.begin(), permutation.end())) { 277 return emitOpError() 278 << "expects iterator_interchange to be a permutation, found " 279 << getIteratorInterchange(); 280 } 281 return success(); 282 } 283 284 //===---------------------------------------------------------------------===// 285 // PadOp 286 //===---------------------------------------------------------------------===// 287 288 DiagnosedSilenceableFailure 289 transform::PadOp::applyToOne(linalg::LinalgOp target, 290 SmallVectorImpl<Operation *> &results, 291 transform::TransformState &state) { 292 // Convert the integer packing flags to booleans. 293 SmallVector<bool> packPaddings; 294 for (int64_t packPadding : extractI64Array(getPackPaddings())) 295 packPaddings.push_back(static_cast<bool>(packPadding)); 296 297 // Convert the padding values to attributes. 298 SmallVector<Attribute> paddingValues; 299 for (auto const &it : 300 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 301 Attribute attr = std::get<0>(it); 302 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 303 // Try to parse string attributes to obtain an attribute of element type. 304 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 305 paddingValues.push_back( 306 parseAttribute(attr.cast<StringAttr>(), elementType)); 307 if (!paddingValues.back()) { 308 auto diag = this->emitOpError("expects a padding that parses to ") 309 << elementType << ", got " << std::get<0>(it); 310 diag.attachNote(target.getLoc()) << "when applied to this op"; 311 return DiagnosedSilenceableFailure::definiteFailure(); 312 } 313 continue; 314 } 315 // Otherwise, add the attribute directly. 316 if (attr.getType() != elementType) { 317 auto diag = this->emitOpError("expects a padding value of type ") 318 << elementType << ", got " << attr; 319 diag.attachNote(target.getLoc()) << "when applied to this op"; 320 return DiagnosedSilenceableFailure::definiteFailure(); 321 } 322 paddingValues.push_back(attr); 323 } 324 325 // Extract the transpose vectors. 326 SmallVector<SmallVector<int64_t>> transposePaddings; 327 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 328 transposePaddings.push_back( 329 extractI64Array(transposeVector.cast<ArrayAttr>())); 330 331 LinalgPaddingOptions paddingOptions; 332 paddingOptions.setPaddingValues(paddingValues); 333 paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); 334 paddingOptions.setPackPaddings(packPaddings); 335 paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); 336 paddingOptions.setTransposePaddings(transposePaddings); 337 338 FailureOr<LinalgOp> result = 339 tryApply<LinalgPaddingPattern>(target, paddingOptions); 340 if (succeeded(result)) { 341 results.push_back(result->getOperation()); 342 return DiagnosedSilenceableFailure(success()); 343 } 344 345 results.assign(1, nullptr); 346 return emitDefaultSilenceableFailure(target); 347 } 348 349 LogicalResult transform::PadOp::verify() { 350 SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings()); 351 if (any_of(packPaddings, [](int64_t packPadding) { 352 return packPadding != 0 && packPadding != 1; 353 })) { 354 return emitOpError() 355 << "expects pack_paddings to contain booleans (0/1), found " 356 << getPackPaddings(); 357 } 358 359 SmallVector<int64_t> paddingDimensions = 360 extractI64Array(getPaddingDimensions()); 361 if (any_of(paddingDimensions, 362 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 363 return emitOpError() 364 << "expects padding_dimensions to contain positive integers, found " 365 << getPaddingDimensions(); 366 } 367 368 SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings()); 369 if (any_of(hoistPaddings, 370 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 371 return emitOpError() 372 << "expects hoist_paddings to contain positive integers, found " 373 << getHoistPaddings(); 374 } 375 376 ArrayAttr transposes = getTransposePaddings(); 377 for (Attribute attr : transposes) { 378 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 379 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 380 if (!std::is_permutation(sequence.begin(), sequence.end(), 381 transpose.begin(), transpose.end())) { 382 return emitOpError() 383 << "expects transpose_paddings to be a permutation, found " 384 << attr; 385 } 386 } 387 return success(); 388 } 389 390 //===----------------------------------------------------------------------===// 391 // ScalarizeOp 392 //===----------------------------------------------------------------------===// 393 394 DiagnosedSilenceableFailure 395 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, 396 SmallVectorImpl<Operation *> &results, 397 transform::TransformState &state) { 398 LinalgTilingOptions tilingOptions; 399 tilingOptions.scalarizeDynamicDims(); 400 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 401 // sizes and asserts that it is not already set. 402 SmallVector<int64_t> emptyTileSizes; 403 LinalgTilingPattern pattern(getContext(), tilingOptions); 404 SimpleRewriter rewriter(getContext()); 405 rewriter.setInsertionPoint(target); 406 FailureOr<TiledLinalgOp> result = 407 pattern.returningMatchAndRewrite(target, rewriter); 408 if (failed(result)) 409 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 410 411 results.push_back(result->op); 412 return DiagnosedSilenceableFailure(success()); 413 } 414 415 //===----------------------------------------------------------------------===// 416 // SplitOp 417 //===----------------------------------------------------------------------===// 418 419 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, 420 TransformState &state) { 421 // Collect the dynamic split points if provided. 422 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); 423 SimpleRewriter rewriter(getContext()); 424 SmallVector<OpFoldResult> splitPoints; 425 splitPoints.reserve(payload.size()); 426 if (getDynamicSplitPoint()) { 427 auto diag = DiagnosedSilenceableFailure::success(); 428 splitPoints = llvm::to_vector(llvm::map_range( 429 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { 430 if (op->getNumResults() != 1 || 431 !op->getResult(0).getType().isIndex()) { 432 diag = emitSilenceableError() 433 << "expected dynamic split point handle to point to a " 434 "single-result index-typed op"; 435 diag.attachNote(op->getLoc()) << "dynamic split point"; 436 } 437 return OpFoldResult(op->getResult(0)); 438 })); 439 if (!diag.succeeded()) 440 return diag; 441 442 if (splitPoints.size() != payload.size()) { 443 emitError() << "expected the dynamic split point handle to point to as " 444 "many operations (" 445 << splitPoints.size() << ") as the target handle (" 446 << payload.size() << ")"; 447 return DiagnosedSilenceableFailure::definiteFailure(); 448 } 449 } else { 450 splitPoints.resize(payload.size(), 451 rewriter.getIndexAttr(getStaticSplitPoint())); 452 } 453 454 // Split each target operation. 455 SmallVector<Operation *> first, second; 456 for (const auto &pair : llvm::zip(payload, splitPoints)) { 457 Operation *target = std::get<0>(pair); 458 auto linalgOp = dyn_cast<LinalgOp>(target); 459 if (!linalgOp) { 460 auto diag = emitSilenceableError() << "only applies to structured ops"; 461 diag.attachNote(target->getLoc()) << "target op"; 462 return diag; 463 } 464 465 if (getDimension() >= linalgOp.getNumLoops()) { 466 auto diag = emitSilenceableError() << "dimension " << getDimension() 467 << " does not exist in target op"; 468 diag.attachNote(target->getLoc()) << "target op"; 469 return diag; 470 } 471 472 rewriter.setInsertionPoint(linalgOp); 473 std::tie(first.emplace_back(), second.emplace_back()) = 474 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); 475 } 476 477 results.set(getFirst().cast<OpResult>(), first); 478 results.set(getSecond().cast<OpResult>(), second); 479 return DiagnosedSilenceableFailure::success(); 480 } 481 482 void SplitOp::getEffects( 483 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 484 // The target handle is consumed. 485 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 486 TransformMappingResource::get()); 487 effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 488 TransformMappingResource::get()); 489 490 // The dynamic split point handle is not consumed. 491 if (getDynamicSplitPoint()) { 492 effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(), 493 TransformMappingResource::get()); 494 } 495 496 // The resulting handles are produced. 497 for (Value result : getResults()) { 498 effects.emplace_back(MemoryEffects::Allocate::get(), result, 499 TransformMappingResource::get()); 500 effects.emplace_back(MemoryEffects::Write::get(), result, 501 TransformMappingResource::get()); 502 } 503 504 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 505 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 506 } 507 508 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 509 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; 510 IntegerAttr staticSplitPoint; 511 auto pdlOperationType = 512 pdl::OperationType::get(parser.getBuilder().getContext()); 513 if (parser.parseOperand(target) || 514 parser.resolveOperand(target, pdlOperationType, result.operands) || 515 parser.parseKeyword("after")) 516 return failure(); 517 518 OptionalParseResult dynamicPointParseResult = 519 parser.parseOptionalOperand(dynamicSplitPoint); 520 if (!dynamicPointParseResult.hasValue()) { 521 int64_t staticSplitPointValue; 522 if (failed(parser.parseInteger(staticSplitPointValue))) 523 return failure(); 524 525 staticSplitPoint = 526 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); 527 } else { 528 if (failed(*dynamicPointParseResult) || 529 parser.resolveOperand(dynamicSplitPoint, pdlOperationType, 530 result.operands)) { 531 return failure(); 532 } 533 534 staticSplitPoint = 535 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); 536 } 537 538 result.addAttribute( 539 SplitOp::getStaticSplitPointAttrName(result.name).getValue(), 540 staticSplitPoint); 541 if (failed(parser.parseOptionalAttrDict(result.attributes))) 542 return failure(); 543 544 result.addTypes({pdlOperationType, pdlOperationType}); 545 return success(); 546 } 547 548 void SplitOp::print(OpAsmPrinter &printer) { 549 printer << " " << getTarget() << " after "; 550 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); 551 if (staticSplitSize != ShapedType::kDynamicSize) 552 printer << staticSplitSize; 553 else 554 printer << getDynamicSplitPoint(); 555 printer << " "; 556 printer.printOptionalAttrDict(getOperation()->getAttrs(), 557 {getStaticSplitPointAttrName()}); 558 } 559 560 LogicalResult SplitOp::verify() { 561 if ((static_cast<int64_t>(getStaticSplitPoint()) != 562 ShapedType::kDynamicSize) ^ 563 (getDynamicSplitPoint() == nullptr)) { 564 return emitOpError() 565 << "expects either a dynamic or a static split point to be provided"; 566 } 567 return success(); 568 } 569 570 //===----------------------------------------------------------------------===// 571 // SplitReductionOp 572 //===----------------------------------------------------------------------===// 573 574 DiagnosedSilenceableFailure 575 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, 576 SmallVectorImpl<Operation *> &results, 577 transform::TransformState &state) { 578 ControlSplitReductionFn splitFn = [&](LinalgOp) { 579 return std::pair<int64_t, unsigned>(getSplitFactor(), 580 getInsertSplitDimension()); 581 }; 582 SimpleRewriter rewriter(getContext()); 583 rewriter.setInsertionPoint(target); 584 FailureOr<SplitReductionResult> splitResult = 585 (getUseScalingAlgorithm()) 586 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 587 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 588 if (failed(splitResult)) 589 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 590 591 results.push_back(splitResult->initOrAlloc); 592 results.push_back(splitResult->fillOp); 593 results.push_back(splitResult->splitLinalgOp); 594 results.push_back(splitResult->resultCombiningLinalgOp); 595 return DiagnosedSilenceableFailure(success()); 596 } 597 598 //===----------------------------------------------------------------------===// 599 // TileOp 600 //===----------------------------------------------------------------------===// 601 602 DiagnosedSilenceableFailure 603 transform::TileOp::apply(TransformResults &transformResults, 604 TransformState &state) { 605 LinalgTilingOptions tilingOptions; 606 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 607 608 if (!tileSizes.empty()) 609 tilingOptions.setTileSizes(tileSizes); 610 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 611 LinalgTilingPattern pattern(getContext(), tilingOptions); 612 613 LogicalResult result = applyTilingToAll( 614 getOperation(), getTarget(), tileSizes, transformResults, state, 615 [&](LinalgOp linalgOp) { 616 SimpleRewriter rewriter(linalgOp.getContext()); 617 return pattern.returningMatchAndRewrite(linalgOp, rewriter); 618 }); 619 return DiagnosedSilenceableFailure(result); 620 } 621 622 ParseResult transform::TileOp::parse(OpAsmParser &parser, 623 OperationState &result) { 624 return parseTileLikeOp(parser, result, 625 TileOp::getSizesAttrName(result.name).getValue()); 626 } 627 628 void TileOp::print(OpAsmPrinter &p) { 629 p << ' '; 630 p << getTarget(); 631 p.printOptionalAttrDict((*this)->getAttrs()); 632 } 633 634 //===----------------------------------------------------------------------===// 635 // VectorizeOp 636 //===----------------------------------------------------------------------===// 637 638 DiagnosedSilenceableFailure 639 transform::VectorizeOp::applyToOne(Operation *target, 640 SmallVectorImpl<Operation *> &results, 641 transform::TransformState &state) { 642 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 643 auto diag = this->emitOpError("requires isolated-from-above targets"); 644 diag.attachNote(target->getLoc()) << "non-isolated target"; 645 return DiagnosedSilenceableFailure::definiteFailure(); 646 } 647 648 MLIRContext *ctx = getContext(); 649 RewritePatternSet patterns(ctx); 650 patterns.add<LinalgVectorizationPattern>(ctx); 651 652 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 653 vector::populateVectorReductionToContractPatterns(patterns); 654 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 655 linalg::LinalgCopyVTWForwardingPattern>(ctx, 656 /*benefit=*/2); 657 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 658 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 659 if (getVectorizePadding()) 660 linalg::populatePadOpVectorizationPatterns(patterns); 661 662 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 663 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 664 665 results.push_back(target); 666 return DiagnosedSilenceableFailure(success()); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // Transform op registration 671 //===----------------------------------------------------------------------===// 672 673 namespace { 674 /// Registers new ops and declares PDL as dependent dialect since the additional 675 /// ops are using PDL types for operands and results. 676 class LinalgTransformDialectExtension 677 : public transform::TransformDialectExtension< 678 LinalgTransformDialectExtension> { 679 public: 680 LinalgTransformDialectExtension() { 681 declareDependentDialect<pdl::PDLDialect>(); 682 declareDependentDialect<scf::SCFDialect>(); 683 declareDependentDialect<vector::VectorDialect>(); 684 registerTransformOps< 685 #define GET_OP_LIST 686 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 687 >(); 688 } 689 }; 690 } // namespace 691 692 #define GET_OP_CLASSES 693 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 694 695 void mlir::linalg::registerTransformDialectExtension( 696 DialectRegistry ®istry) { 697 registry.addExtensions<LinalgTransformDialectExtension>(); 698 } 699