1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Linalg/IR/Linalg.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/Interfaces/TilingInterface.h" 20 #include "mlir/Parser/Parser.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 using namespace mlir::transform; 26 27 /// Extracts a vector of unsigned from an array attribute. Asserts if the 28 /// attribute contains values other than intergers. May truncate. 29 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 30 SmallVector<unsigned> result; 31 result.reserve(attr.size()); 32 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 33 result.push_back(value.getZExtValue()); 34 return result; 35 } 36 37 namespace { 38 /// A simple pattern rewriter that implements no special logic. 39 class SimpleRewriter : public PatternRewriter { 40 public: 41 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 42 }; 43 } // namespace 44 45 /// Attempts to apply the pattern specified as template argument to the given 46 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 47 /// function that returns the "main" result or failure. Returns failure if the 48 /// pattern failed to apply. Extra arguments are forwarded to the pattern 49 /// constructor. 50 template <typename PatternTy, typename... Args> 51 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 52 // Check if the given operation has the type expected by the pattern. 53 using OpTy = typename llvm::function_traits< 54 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 55 auto op = dyn_cast<OpTy>(operation); 56 if (!op) 57 return failure(); 58 59 // Apply the pattern directly to the op. 60 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 61 SimpleRewriter rewriter(operation->getContext()); 62 rewriter.setInsertionPoint(operation); 63 auto result = pattern.returningMatchAndRewrite(op, rewriter); 64 if (failed(result)) 65 return failure(); 66 return cast<LinalgOp>(result->getOperation()); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // DecomposeOp 71 //===----------------------------------------------------------------------===// 72 73 DiagnosedSilenceableFailure 74 transform::DecomposeOp::applyToOne(linalg::LinalgOp target, 75 SmallVectorImpl<Operation *> &results, 76 transform::TransformState &state) { 77 FailureOr<LinalgOp> windowed = 78 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 79 if (succeeded(windowed)) { 80 results.push_back(*windowed); 81 return DiagnosedSilenceableFailure(success()); 82 } 83 FailureOr<LinalgOp> depthwise = 84 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 85 if (succeeded(depthwise)) { 86 results.push_back(*depthwise); 87 return DiagnosedSilenceableFailure(success()); 88 } 89 results.assign(1, nullptr); 90 return emitDefaultSilenceableFailure(target); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // FuseOp 95 //===----------------------------------------------------------------------===// 96 97 /// Apply a tiling transformation to all payload ops and store both the 98 /// tiled operation as well as the created tile loops. 99 static LogicalResult 100 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps, 101 unsigned numLoops, 102 transform::TransformResults &transformResults, 103 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 104 SmallVector<Operation *> tiledLinalgOps; 105 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 106 for (unsigned int i = 0; i < numLoops; ++i) 107 loopOps[i].reserve(payloadOps.size()); 108 109 for (Operation *target : payloadOps) { 110 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 111 if (!linalgOp) 112 return transformOp->emitError("only LinalgOps are supported"); 113 114 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 115 if (failed(tiled)) 116 return failure(); 117 118 tiledLinalgOps.push_back(tiled->op); 119 if (tiled->loops.size() != numLoops) 120 // Not enough loops were generated. This usually means that the input size 121 // was smaller than the tiling size. 122 // TODO: LinalgTilingPattern should return failure(). 123 return failure(); 124 for (unsigned int i = 0; i < numLoops; ++i) 125 loopOps[i].push_back(tiled->loops[i]); 126 } 127 128 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 129 for (unsigned int i = 0; i < numLoops; ++i) 130 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 131 return success(); 132 } 133 134 /// Parse a tiling-like operation that returns the tiled op as well as the 135 /// created tile loops. The function counts the non-zero tile sizes to compute 136 /// the number of results. 137 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 138 StringRef sizesAttrName) { 139 OpAsmParser::UnresolvedOperand targetOperand; 140 SMLoc opLoc = parser.getCurrentLocation(); 141 if (parser.parseOperand(targetOperand) || 142 parser.parseOptionalAttrDict(result.attributes)) 143 return failure(); 144 Attribute sizesAttr = result.attributes.get(sizesAttrName); 145 if (!sizesAttr) 146 return parser.emitError(opLoc) 147 << "expected '" << sizesAttrName << "' attribute"; 148 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 149 if (!sizesArrayAttr) 150 return parser.emitError(opLoc) 151 << "'" << sizesAttrName << "' attribute must be an array"; 152 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 153 size_t numExpectedLoops = 154 sizesArrayAttr.size() - 155 llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); 156 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 157 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 158 return failure(); 159 return success(); 160 } 161 162 DiagnosedSilenceableFailure 163 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 164 mlir::transform::TransformState &state) { 165 LinalgTilingAndFusionOptions fusionOptions; 166 fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); 167 fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); 168 169 LogicalResult result = applyTilingToAll( 170 getOperation(), state.getPayloadOps(getTarget()), 171 fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), 172 transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 173 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 174 SimpleRewriter rewriter(getContext()); 175 rewriter.setInsertionPoint(linalgOp); 176 FailureOr<TileLoopNest> tileLoopNest = 177 pattern.returningMatchAndRewrite(linalgOp, rewriter); 178 if (failed(tileLoopNest)) 179 return failure(); 180 181 TiledLinalgOp tiledLinalgOp; 182 tiledLinalgOp.op = tileLoopNest->getRootOp(); 183 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 184 tileLoopNest->getLoopOps().end()}; 185 return tiledLinalgOp; 186 }); 187 return DiagnosedSilenceableFailure(result); 188 } 189 190 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 191 OperationState &result) { 192 return parseTileLikeOp( 193 parser, result, 194 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 195 } 196 197 void transform::FuseOp::print(OpAsmPrinter &p) { 198 p << ' '; 199 p << getTarget(); 200 p.printOptionalAttrDict((*this)->getAttrs()); 201 } 202 203 LogicalResult transform::FuseOp::verify() { 204 SmallVector<int64_t> permutation = 205 extractFromI64ArrayAttr(getTileInterchange()); 206 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 207 if (!std::is_permutation(sequence.begin(), sequence.end(), 208 permutation.begin(), permutation.end())) { 209 return emitOpError() << "expects interchange to be a permutation, found " 210 << getTileInterchange(); 211 } 212 return success(); 213 } 214 215 //===----------------------------------------------------------------------===// 216 // GeneralizeOp 217 //===----------------------------------------------------------------------===// 218 219 DiagnosedSilenceableFailure 220 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, 221 SmallVectorImpl<Operation *> &results, 222 transform::TransformState &state) { 223 // Exit early if no transformation is needed. 224 if (isa<GenericOp>(target)) { 225 results.push_back(target); 226 return DiagnosedSilenceableFailure(success()); 227 } 228 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 229 if (succeeded(generic)) { 230 results.push_back(generic->getOperation()); 231 return DiagnosedSilenceableFailure(success()); 232 } 233 results.assign(1, nullptr); 234 return emitDefaultSilenceableFailure(target); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // InterchangeOp 239 //===----------------------------------------------------------------------===// 240 241 DiagnosedSilenceableFailure 242 transform::InterchangeOp::applyToOne(linalg::GenericOp target, 243 SmallVectorImpl<Operation *> &results, 244 transform::TransformState &state) { 245 SmallVector<unsigned> interchangeVector = 246 extractUIntArray(getIteratorInterchange()); 247 // Exit early if no transformation is needed. 248 if (interchangeVector.empty()) { 249 results.push_back(target); 250 return DiagnosedSilenceableFailure(success()); 251 } 252 SimpleRewriter rewriter(target->getContext()); 253 FailureOr<GenericOp> res = 254 interchangeGenericOp(rewriter, target, interchangeVector); 255 if (failed(res)) 256 return DiagnosedSilenceableFailure::definiteFailure(); 257 results.push_back(res->getOperation()); 258 return DiagnosedSilenceableFailure(success()); 259 } 260 261 LogicalResult transform::InterchangeOp::verify() { 262 SmallVector<unsigned> permutation = 263 extractUIntArray(getIteratorInterchange()); 264 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 265 if (!std::is_permutation(sequence.begin(), sequence.end(), 266 permutation.begin(), permutation.end())) { 267 return emitOpError() 268 << "expects iterator_interchange to be a permutation, found " 269 << getIteratorInterchange(); 270 } 271 return success(); 272 } 273 274 //===---------------------------------------------------------------------===// 275 // MultiTileSizesOp 276 //===---------------------------------------------------------------------===// 277 278 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( 279 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) { 280 OpBuilder builder(target.getContext()); 281 builder.setInsertionPoint(target); 282 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); 283 OpFoldResult divisor = builder.getIndexAttr(getDivisor()); 284 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( 285 builder, target, getDimension(), targetSize, divisor); 286 if (failed(spec)) { 287 return emitSilenceableError() << "could not generate tile size computation"; 288 } 289 290 AffineExpr s0 = builder.getAffineSymbolExpr(0); 291 AffineExpr s1 = builder.getAffineSymbolExpr(1); 292 Operation *splitPoint = 293 makeComposedAffineApply(builder, target.getLoc(), s0 * s1, 294 {spec->lowTileSize, spec->lowTripCount}); 295 Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); 296 Operation *highTileSize = spec->highTileSize.getDefiningOp(); 297 assert(lowTileSize && highTileSize && splitPoint && 298 "tile sizes are not produced by operations"); 299 results.reserve(results.size() + 3); 300 results.push_back(lowTileSize); 301 results.push_back(highTileSize); 302 results.push_back(splitPoint); 303 return DiagnosedSilenceableFailure::success(); 304 } 305 306 void transform::MultiTileSizesOp::getEffects( 307 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 308 onlyReadsHandle(getTarget(), effects); 309 producesHandle(getResults(), effects); 310 modifiesPayload(effects); 311 } 312 313 //===---------------------------------------------------------------------===// 314 // PadOp 315 //===---------------------------------------------------------------------===// 316 317 DiagnosedSilenceableFailure 318 transform::PadOp::applyToOne(linalg::LinalgOp target, 319 SmallVectorImpl<Operation *> &results, 320 transform::TransformState &state) { 321 // Convert the integer packing flags to booleans. 322 SmallVector<bool> packPaddings; 323 for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) 324 packPaddings.push_back(static_cast<bool>(packPadding)); 325 326 // Convert the padding values to attributes. 327 SmallVector<Attribute> paddingValues; 328 for (auto const &it : 329 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 330 Attribute attr = std::get<0>(it); 331 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 332 // Try to parse string attributes to obtain an attribute of element type. 333 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 334 paddingValues.push_back( 335 parseAttribute(attr.cast<StringAttr>(), elementType)); 336 if (!paddingValues.back()) { 337 auto diag = this->emitOpError("expects a padding that parses to ") 338 << elementType << ", got " << std::get<0>(it); 339 diag.attachNote(target.getLoc()) << "when applied to this op"; 340 return DiagnosedSilenceableFailure::definiteFailure(); 341 } 342 continue; 343 } 344 // Otherwise, add the attribute directly. 345 if (attr.getType() != elementType) { 346 auto diag = this->emitOpError("expects a padding value of type ") 347 << elementType << ", got " << attr; 348 diag.attachNote(target.getLoc()) << "when applied to this op"; 349 return DiagnosedSilenceableFailure::definiteFailure(); 350 } 351 paddingValues.push_back(attr); 352 } 353 354 // Extract the transpose vectors. 355 SmallVector<SmallVector<int64_t>> transposePaddings; 356 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 357 transposePaddings.push_back( 358 extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>())); 359 360 LinalgPaddingOptions paddingOptions; 361 paddingOptions.setPaddingValues(paddingValues); 362 paddingOptions.setPaddingDimensions( 363 extractFromI64ArrayAttr(getPaddingDimensions())); 364 paddingOptions.setPackPaddings(packPaddings); 365 paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); 366 paddingOptions.setTransposePaddings(transposePaddings); 367 368 FailureOr<LinalgOp> result = 369 tryApply<LinalgPaddingPattern>(target, paddingOptions); 370 if (succeeded(result)) { 371 results.push_back(result->getOperation()); 372 return DiagnosedSilenceableFailure(success()); 373 } 374 375 results.assign(1, nullptr); 376 return emitDefaultSilenceableFailure(target); 377 } 378 379 LogicalResult transform::PadOp::verify() { 380 SmallVector<int64_t> packPaddings = 381 extractFromI64ArrayAttr(getPackPaddings()); 382 if (any_of(packPaddings, [](int64_t packPadding) { 383 return packPadding != 0 && packPadding != 1; 384 })) { 385 return emitOpError() 386 << "expects pack_paddings to contain booleans (0/1), found " 387 << getPackPaddings(); 388 } 389 390 SmallVector<int64_t> paddingDimensions = 391 extractFromI64ArrayAttr(getPaddingDimensions()); 392 if (any_of(paddingDimensions, 393 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 394 return emitOpError() 395 << "expects padding_dimensions to contain positive integers, found " 396 << getPaddingDimensions(); 397 } 398 399 SmallVector<int64_t> hoistPaddings = 400 extractFromI64ArrayAttr(getHoistPaddings()); 401 if (any_of(hoistPaddings, 402 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 403 return emitOpError() 404 << "expects hoist_paddings to contain positive integers, found " 405 << getHoistPaddings(); 406 } 407 408 ArrayAttr transposes = getTransposePaddings(); 409 for (Attribute attr : transposes) { 410 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 411 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 412 if (!std::is_permutation(sequence.begin(), sequence.end(), 413 transpose.begin(), transpose.end())) { 414 return emitOpError() 415 << "expects transpose_paddings to be a permutation, found " 416 << attr; 417 } 418 } 419 return success(); 420 } 421 422 //===----------------------------------------------------------------------===// 423 // PromoteOp 424 //===----------------------------------------------------------------------===// 425 426 DiagnosedSilenceableFailure 427 transform::PromoteOp::applyToOne(linalg::LinalgOp target, 428 SmallVectorImpl<Operation *> &results, 429 transform::TransformState &state) { 430 LinalgPromotionOptions promotionOptions; 431 if (!getOperandsToPromote().empty()) 432 promotionOptions = promotionOptions.setOperandsToPromote( 433 extractFromI64ArrayAttr(getOperandsToPromote())); 434 if (getUseFullTilesByDefault()) 435 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( 436 getUseFullTilesByDefault()); 437 if (getUseAlloca()) 438 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); 439 if (!getUseFullTileBuffers().empty()) 440 promotionOptions = promotionOptions.setUseFullTileBuffers( 441 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); 442 if (getAlignment().has_value()) 443 promotionOptions = promotionOptions.setAlignment(*getAlignment()); 444 445 if (failed(promoteSubviewsPrecondition(target, promotionOptions))) 446 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 447 448 SimpleRewriter rewriter(target->getContext()); 449 rewriter.setInsertionPoint(target); 450 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); 451 if (failed(res)) 452 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 453 results.push_back(target); 454 return DiagnosedSilenceableFailure(success()); 455 } 456 457 //===----------------------------------------------------------------------===// 458 // ScalarizeOp 459 //===----------------------------------------------------------------------===// 460 461 DiagnosedSilenceableFailure 462 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, 463 SmallVectorImpl<Operation *> &results, 464 transform::TransformState &state) { 465 LinalgTilingOptions tilingOptions; 466 tilingOptions.scalarizeDynamicDims(); 467 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 468 // sizes and asserts that it is not already set. 469 SmallVector<int64_t> emptyTileSizes; 470 LinalgTilingPattern pattern(getContext(), tilingOptions); 471 SimpleRewriter rewriter(getContext()); 472 rewriter.setInsertionPoint(target); 473 FailureOr<TiledLinalgOp> result = 474 pattern.returningMatchAndRewrite(target, rewriter); 475 if (failed(result)) 476 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 477 478 results.push_back(result->op); 479 return DiagnosedSilenceableFailure(success()); 480 } 481 482 //===----------------------------------------------------------------------===// 483 // SplitOp 484 //===----------------------------------------------------------------------===// 485 486 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, 487 TransformState &state) { 488 // Collect the dynamic split points if provided. 489 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); 490 SimpleRewriter rewriter(getContext()); 491 SmallVector<OpFoldResult> splitPoints; 492 splitPoints.reserve(payload.size()); 493 if (getDynamicSplitPoint()) { 494 auto diag = DiagnosedSilenceableFailure::success(); 495 splitPoints = llvm::to_vector(llvm::map_range( 496 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { 497 if (op->getNumResults() != 1 || 498 !op->getResult(0).getType().isIndex()) { 499 diag = emitSilenceableError() 500 << "expected dynamic split point handle to point to a " 501 "single-result index-typed op"; 502 diag.attachNote(op->getLoc()) << "dynamic split point"; 503 } 504 return OpFoldResult(op->getResult(0)); 505 })); 506 if (!diag.succeeded()) 507 return diag; 508 509 if (splitPoints.size() != payload.size()) { 510 emitError() << "expected the dynamic split point handle to point to as " 511 "many operations (" 512 << splitPoints.size() << ") as the target handle (" 513 << payload.size() << ")"; 514 return DiagnosedSilenceableFailure::definiteFailure(); 515 } 516 } else { 517 splitPoints.resize(payload.size(), 518 rewriter.getIndexAttr(getStaticSplitPoint())); 519 } 520 521 // Split each target operation. 522 SmallVector<Operation *> first, second; 523 for (const auto &pair : llvm::zip(payload, splitPoints)) { 524 Operation *target = std::get<0>(pair); 525 auto linalgOp = dyn_cast<LinalgOp>(target); 526 if (!linalgOp) { 527 auto diag = emitSilenceableError() << "only applies to structured ops"; 528 diag.attachNote(target->getLoc()) << "target op"; 529 return diag; 530 } 531 532 if (getDimension() >= linalgOp.getNumLoops()) { 533 auto diag = emitSilenceableError() << "dimension " << getDimension() 534 << " does not exist in target op"; 535 diag.attachNote(target->getLoc()) << "target op"; 536 return diag; 537 } 538 539 rewriter.setInsertionPoint(linalgOp); 540 std::tie(first.emplace_back(), second.emplace_back()) = 541 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); 542 } 543 544 results.set(getFirst().cast<OpResult>(), first); 545 results.set(getSecond().cast<OpResult>(), second); 546 return DiagnosedSilenceableFailure::success(); 547 } 548 549 void SplitOp::getEffects( 550 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 551 consumesHandle(getTarget(), effects); 552 if (getDynamicSplitPoint()) 553 onlyReadsHandle(getDynamicSplitPoint(), effects); 554 producesHandle(getResults(), effects); 555 modifiesPayload(effects); 556 } 557 558 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 559 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; 560 IntegerAttr staticSplitPoint; 561 auto pdlOperationType = 562 pdl::OperationType::get(parser.getBuilder().getContext()); 563 if (parser.parseOperand(target) || 564 parser.resolveOperand(target, pdlOperationType, result.operands) || 565 parser.parseKeyword("after")) 566 return failure(); 567 568 OptionalParseResult dynamicPointParseResult = 569 parser.parseOptionalOperand(dynamicSplitPoint); 570 if (!dynamicPointParseResult.hasValue()) { 571 int64_t staticSplitPointValue; 572 if (failed(parser.parseInteger(staticSplitPointValue))) 573 return failure(); 574 575 staticSplitPoint = 576 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); 577 } else { 578 if (failed(*dynamicPointParseResult) || 579 parser.resolveOperand(dynamicSplitPoint, pdlOperationType, 580 result.operands)) { 581 return failure(); 582 } 583 584 staticSplitPoint = 585 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); 586 } 587 588 result.addAttribute( 589 SplitOp::getStaticSplitPointAttrName(result.name).getValue(), 590 staticSplitPoint); 591 if (failed(parser.parseOptionalAttrDict(result.attributes))) 592 return failure(); 593 594 result.addTypes({pdlOperationType, pdlOperationType}); 595 return success(); 596 } 597 598 void SplitOp::print(OpAsmPrinter &printer) { 599 printer << " " << getTarget() << " after "; 600 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); 601 if (staticSplitSize != ShapedType::kDynamicSize) 602 printer << staticSplitSize; 603 else 604 printer << getDynamicSplitPoint(); 605 printer << " "; 606 printer.printOptionalAttrDict(getOperation()->getAttrs(), 607 {getStaticSplitPointAttrName()}); 608 } 609 610 LogicalResult SplitOp::verify() { 611 if ((static_cast<int64_t>(getStaticSplitPoint()) != 612 ShapedType::kDynamicSize) ^ 613 (getDynamicSplitPoint() == nullptr)) { 614 return emitOpError() 615 << "expects either a dynamic or a static split point to be provided"; 616 } 617 return success(); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // SplitReductionOp 622 //===----------------------------------------------------------------------===// 623 624 DiagnosedSilenceableFailure 625 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, 626 SmallVectorImpl<Operation *> &results, 627 transform::TransformState &state) { 628 ControlSplitReductionFn splitFn = [&](LinalgOp) { 629 return std::pair<int64_t, unsigned>(getSplitFactor(), 630 getInsertSplitDimension()); 631 }; 632 SimpleRewriter rewriter(getContext()); 633 rewriter.setInsertionPoint(target); 634 FailureOr<SplitReductionResult> splitResult = 635 (getUseScalingAlgorithm()) 636 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 637 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 638 if (failed(splitResult)) 639 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 640 641 results.push_back(splitResult->initOrAlloc); 642 results.push_back(splitResult->fillOp); 643 results.push_back(splitResult->splitLinalgOp); 644 results.push_back(splitResult->resultCombiningLinalgOp); 645 return DiagnosedSilenceableFailure(success()); 646 } 647 648 //===----------------------------------------------------------------------===// 649 // TileOp 650 //===----------------------------------------------------------------------===// 651 652 DiagnosedSilenceableFailure 653 transform::TileOp::apply(TransformResults &transformResults, 654 TransformState &state) { 655 LinalgTilingOptions tilingOptions; 656 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 657 658 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); 659 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; 660 dynamicSizeProducers.reserve(getDynamicSizes().size()); 661 for (Value dynamicSizeProducerHandle : getDynamicSizes()) { 662 dynamicSizeProducers.push_back( 663 state.getPayloadOps(dynamicSizeProducerHandle)); 664 665 if (dynamicSizeProducers.back().size() != targets.size()) { 666 DiagnosedSilenceableFailure diag = 667 emitSilenceableError() 668 << "expected as many dynamic size-producing operations (" 669 << dynamicSizeProducers.back().size() << ") as target ops (" 670 << targets.size() << ")"; 671 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 672 return diag; 673 } 674 675 for (Operation *op : dynamicSizeProducers.back()) { 676 if (op->getNumResults() == 1 && 677 op->getResult(0).getType().isa<IndexType>()) 678 continue; 679 DiagnosedSilenceableFailure diag = 680 emitSilenceableError() << "expected sizes to be produced by ops " 681 "with a single index-type result"; 682 diag.attachNote(op->getLoc()) << "size producer op"; 683 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 684 return diag; 685 } 686 } 687 688 SmallVector<Operation *> tiled; 689 SmallVector<SmallVector<Operation *, 4>, 4> loops; 690 loops.resize(getLoops().size()); 691 for (auto &en : llvm::enumerate(targets)) { 692 auto linalgOp = dyn_cast<LinalgOp>(en.value()); 693 if (!linalgOp) { 694 DiagnosedSilenceableFailure diag = emitSilenceableError() 695 << "only linalg ops are supported"; 696 diag.attachNote(en.value()->getLoc()) << "target op"; 697 return diag; 698 } 699 700 unsigned index = en.index(); 701 if (!tileSizes.empty()) { 702 tilingOptions.setTileSizeComputationFunction( 703 [&, index](OpBuilder &b, Operation *) { 704 SmallVector<Value, 4> sizes; 705 sizes.reserve(tileSizes.size()); 706 unsigned dynamicIdx = 0; 707 for (OpFoldResult ofr : getMixedSizes()) { 708 if (auto attr = ofr.dyn_cast<Attribute>()) { 709 sizes.push_back(b.create<arith::ConstantIndexOp>( 710 getLoc(), attr.cast<IntegerAttr>().getInt())); 711 } else { 712 sizes.push_back( 713 dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); 714 } 715 } 716 return sizes; 717 }); 718 } 719 720 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 721 LinalgTilingPattern pattern(getContext(), tilingOptions); 722 SimpleRewriter rewriter(linalgOp.getContext()); 723 FailureOr<TiledLinalgOp> tiledOp = 724 pattern.returningMatchAndRewrite(linalgOp, rewriter); 725 if (failed(tiledOp)) 726 return DiagnosedSilenceableFailure::definiteFailure(); 727 728 tiled.push_back(tiledOp->op); 729 for (const auto &en2 : llvm::enumerate(tiledOp->loops)) 730 loops[en2.index()].push_back(en2.value()); 731 } 732 733 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); 734 for (const auto &en : llvm::enumerate(loops)) 735 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); 736 737 return DiagnosedSilenceableFailure::success(); 738 } 739 740 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() { 741 ValueRange dynamic = getDynamicSizes(); 742 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 743 SmallVector<OpFoldResult> results; 744 results.reserve(tileSizes.size()); 745 unsigned dynamicPos = 0; 746 Builder builder(getContext()); 747 for (int64_t size : tileSizes) { 748 if (size == ShapedType::kDynamicSize) { 749 results.push_back(dynamic[dynamicPos++]); 750 } else { 751 results.push_back(builder.getIndexAttr(size)); 752 } 753 } 754 return results; 755 } 756 757 ParseResult transform::TileOp::parse(OpAsmParser &parser, 758 OperationState &result) { 759 OpAsmParser::UnresolvedOperand target; 760 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; 761 ArrayAttr staticSizes; 762 auto pdlOperationType = pdl::OperationType::get(parser.getContext()); 763 if (parser.parseOperand(target) || 764 parser.resolveOperand(target, pdlOperationType, result.operands) || 765 parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || 766 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || 767 parser.parseOptionalAttrDict(result.attributes)) 768 return ParseResult::failure(); 769 770 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); 771 size_t numExpectedLoops = 772 staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); 773 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); 774 return success(); 775 } 776 777 void TileOp::print(OpAsmPrinter &p) { 778 p << ' ' << getTarget(); 779 printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), 780 getStaticSizes()); 781 p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); 782 } 783 784 void transform::TileOp::getEffects( 785 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 786 consumesHandle(getTarget(), effects); 787 onlyReadsHandle(getDynamicSizes(), effects); 788 producesHandle(getTiledLinalgOp(), effects); 789 producesHandle(getLoops(), effects); 790 modifiesPayload(effects); 791 } 792 793 //===----------------------------------------------------------------------===// 794 // TileToForeachThreadOp 795 //===----------------------------------------------------------------------===// 796 797 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( 798 TilingInterface target, SmallVectorImpl<Operation *> &results, 799 transform::TransformState &state) { 800 IRRewriter rewriter(getContext()); 801 rewriter.setInsertionPoint(target); 802 auto maybeThreadDimMappingAttr = getThreadDimMapping(); 803 FailureOr<ForeachThreadTilingResult> tilingResult = 804 linalg::tileToForeachThreadOp( 805 rewriter, target, getAsOpFoldResult(getNumThreads()), 806 maybeThreadDimMappingAttr 807 ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) 808 : ArrayRef<int64_t>{}); 809 if (failed(tilingResult)) 810 return emitDefaultSilenceableFailure(target); 811 rewriter.replaceOp(target, tilingResult->tileOp->getResults()); 812 results.assign({tilingResult->tileOp, tilingResult->tiledOp}); 813 return DiagnosedSilenceableFailure(success()); 814 } 815 816 //===----------------------------------------------------------------------===// 817 // VectorizeOp 818 //===----------------------------------------------------------------------===// 819 820 DiagnosedSilenceableFailure 821 transform::VectorizeOp::applyToOne(Operation *target, 822 SmallVectorImpl<Operation *> &results, 823 transform::TransformState &state) { 824 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 825 auto diag = this->emitOpError("requires isolated-from-above targets"); 826 diag.attachNote(target->getLoc()) << "non-isolated target"; 827 return DiagnosedSilenceableFailure::definiteFailure(); 828 } 829 830 MLIRContext *ctx = getContext(); 831 RewritePatternSet patterns(ctx); 832 patterns.add<LinalgVectorizationPattern>(ctx); 833 834 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 835 vector::populateVectorReductionToContractPatterns(patterns); 836 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 837 linalg::LinalgCopyVTWForwardingPattern>(ctx, 838 /*benefit=*/2); 839 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 840 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 841 if (getVectorizePadding()) 842 linalg::populatePadOpVectorizationPatterns(patterns); 843 844 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 845 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 846 847 results.push_back(target); 848 return DiagnosedSilenceableFailure(success()); 849 } 850 851 //===----------------------------------------------------------------------===// 852 // Transform op registration 853 //===----------------------------------------------------------------------===// 854 855 namespace { 856 /// Registers new ops and declares PDL as dependent dialect since the additional 857 /// ops are using PDL types for operands and results. 858 class LinalgTransformDialectExtension 859 : public transform::TransformDialectExtension< 860 LinalgTransformDialectExtension> { 861 public: 862 LinalgTransformDialectExtension() { 863 declareDependentDialect<AffineDialect>(); 864 declareDependentDialect<arith::ArithmeticDialect>(); 865 declareDependentDialect<pdl::PDLDialect>(); 866 declareDependentDialect<scf::SCFDialect>(); 867 declareDependentDialect<vector::VectorDialect>(); 868 registerTransformOps< 869 #define GET_OP_LIST 870 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 871 >(); 872 } 873 }; 874 } // namespace 875 876 #define GET_OP_CLASSES 877 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 878 879 void mlir::linalg::registerTransformDialectExtension( 880 DialectRegistry ®istry) { 881 registry.addExtensions<LinalgTransformDialectExtension>(); 882 } 883