1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Linalg/IR/Linalg.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/Parser/Parser.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::linalg; 24 using namespace mlir::transform; 25 26 /// Extracts a vector of int64_t from an array attribute. Asserts if the 27 /// attribute contains values other than integers. 28 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 29 SmallVector<int64_t> result; 30 result.reserve(attr.size()); 31 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 32 result.push_back(value.getSExtValue()); 33 return result; 34 } 35 36 /// Extracts a vector of unsigned from an array attribute. Asserts if the 37 /// attribute contains values other than intergers. May truncate. 38 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 39 SmallVector<unsigned> result; 40 result.reserve(attr.size()); 41 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 42 result.push_back(value.getZExtValue()); 43 return result; 44 } 45 46 namespace { 47 /// A simple pattern rewriter that implements no special logic. 48 class SimpleRewriter : public PatternRewriter { 49 public: 50 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 51 }; 52 } // namespace 53 54 /// Attempts to apply the pattern specified as template argument to the given 55 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 56 /// function that returns the "main" result or failure. Returns failure if the 57 /// pattern failed to apply. Extra arguments are forwarded to the pattern 58 /// constructor. 59 template <typename PatternTy, typename... Args> 60 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 61 // Check if the given operation has the type expected by the pattern. 62 using OpTy = typename llvm::function_traits< 63 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 64 auto op = dyn_cast<OpTy>(operation); 65 if (!op) 66 return failure(); 67 68 // Apply the pattern directly to the op. 69 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 70 SimpleRewriter rewriter(operation->getContext()); 71 rewriter.setInsertionPoint(operation); 72 auto result = pattern.returningMatchAndRewrite(op, rewriter); 73 if (failed(result)) 74 return failure(); 75 return cast<LinalgOp>(result->getOperation()); 76 } 77 78 //===----------------------------------------------------------------------===// 79 // DecomposeOp 80 //===----------------------------------------------------------------------===// 81 82 DiagnosedSilenceableFailure 83 transform::DecomposeOp::applyToOne(linalg::LinalgOp target, 84 SmallVectorImpl<Operation *> &results, 85 transform::TransformState &state) { 86 FailureOr<LinalgOp> windowed = 87 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 88 if (succeeded(windowed)) { 89 results.push_back(*windowed); 90 return DiagnosedSilenceableFailure(success()); 91 } 92 FailureOr<LinalgOp> depthwise = 93 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 94 if (succeeded(depthwise)) { 95 results.push_back(*depthwise); 96 return DiagnosedSilenceableFailure(success()); 97 } 98 results.assign(1, nullptr); 99 return emitDefaultSilenceableFailure(target); 100 } 101 102 //===----------------------------------------------------------------------===// 103 // FuseOp 104 //===----------------------------------------------------------------------===// 105 106 /// Apply a tiling transformation to all payload ops and store both the 107 /// tiled operation as well as the created tile loops. 108 static LogicalResult 109 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps, 110 unsigned numLoops, 111 transform::TransformResults &transformResults, 112 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 113 SmallVector<Operation *> tiledLinalgOps; 114 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 115 for (unsigned int i = 0; i < numLoops; ++i) 116 loopOps[i].reserve(payloadOps.size()); 117 118 for (Operation *target : payloadOps) { 119 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 120 if (!linalgOp) 121 return transformOp->emitError("only LinalgOps are supported"); 122 123 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 124 if (failed(tiled)) 125 return failure(); 126 127 tiledLinalgOps.push_back(tiled->op); 128 if (tiled->loops.size() != numLoops) 129 // Not enough loops were generated. This usually means that the input size 130 // was smaller than the tiling size. 131 // TODO: LinalgTilingPattern should return failure(). 132 return failure(); 133 for (unsigned int i = 0; i < numLoops; ++i) 134 loopOps[i].push_back(tiled->loops[i]); 135 } 136 137 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 138 for (unsigned int i = 0; i < numLoops; ++i) 139 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 140 return success(); 141 } 142 143 /// Parse a tiling-like operation that returns the tiled op as well as the 144 /// created tile loops. The function counts the non-zero tile sizes to compute 145 /// the number of results. 146 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 147 StringRef sizesAttrName) { 148 OpAsmParser::UnresolvedOperand targetOperand; 149 SMLoc opLoc = parser.getCurrentLocation(); 150 if (parser.parseOperand(targetOperand) || 151 parser.parseOptionalAttrDict(result.attributes)) 152 return failure(); 153 Attribute sizesAttr = result.attributes.get(sizesAttrName); 154 if (!sizesAttr) 155 return parser.emitError(opLoc) 156 << "expected '" << sizesAttrName << "' attribute"; 157 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 158 if (!sizesArrayAttr) 159 return parser.emitError(opLoc) 160 << "'" << sizesAttrName << "' attribute must be an array"; 161 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 162 size_t numExpectedLoops = 163 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 164 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 165 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 166 return failure(); 167 return success(); 168 } 169 170 DiagnosedSilenceableFailure 171 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 172 mlir::transform::TransformState &state) { 173 LinalgTilingAndFusionOptions fusionOptions; 174 fusionOptions.tileSizes = extractI64Array(getTileSizes()); 175 fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); 176 177 LogicalResult result = applyTilingToAll( 178 getOperation(), state.getPayloadOps(getTarget()), 179 fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), 180 transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 181 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 182 SimpleRewriter rewriter(getContext()); 183 rewriter.setInsertionPoint(linalgOp); 184 FailureOr<TileLoopNest> tileLoopNest = 185 pattern.returningMatchAndRewrite(linalgOp, rewriter); 186 if (failed(tileLoopNest)) 187 return failure(); 188 189 TiledLinalgOp tiledLinalgOp; 190 tiledLinalgOp.op = tileLoopNest->getRootOp(); 191 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 192 tileLoopNest->getLoopOps().end()}; 193 return tiledLinalgOp; 194 }); 195 return DiagnosedSilenceableFailure(result); 196 } 197 198 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 199 OperationState &result) { 200 return parseTileLikeOp( 201 parser, result, 202 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 203 } 204 205 void transform::FuseOp::print(OpAsmPrinter &p) { 206 p << ' '; 207 p << getTarget(); 208 p.printOptionalAttrDict((*this)->getAttrs()); 209 } 210 211 LogicalResult transform::FuseOp::verify() { 212 SmallVector<int64_t> permutation = extractI64Array(getTileInterchange()); 213 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 214 if (!std::is_permutation(sequence.begin(), sequence.end(), 215 permutation.begin(), permutation.end())) { 216 return emitOpError() << "expects interchange to be a permutation, found " 217 << getTileInterchange(); 218 } 219 return success(); 220 } 221 222 //===----------------------------------------------------------------------===// 223 // GeneralizeOp 224 //===----------------------------------------------------------------------===// 225 226 DiagnosedSilenceableFailure 227 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, 228 SmallVectorImpl<Operation *> &results, 229 transform::TransformState &state) { 230 // Exit early if no transformation is needed. 231 if (isa<GenericOp>(target)) { 232 results.push_back(target); 233 return DiagnosedSilenceableFailure(success()); 234 } 235 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 236 if (succeeded(generic)) { 237 results.push_back(generic->getOperation()); 238 return DiagnosedSilenceableFailure(success()); 239 } 240 results.assign(1, nullptr); 241 return emitDefaultSilenceableFailure(target); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // InterchangeOp 246 //===----------------------------------------------------------------------===// 247 248 DiagnosedSilenceableFailure 249 transform::InterchangeOp::applyToOne(linalg::GenericOp target, 250 SmallVectorImpl<Operation *> &results, 251 transform::TransformState &state) { 252 SmallVector<unsigned> interchangeVector = 253 extractUIntArray(getIteratorInterchange()); 254 // Exit early if no transformation is needed. 255 if (interchangeVector.empty()) { 256 results.push_back(target); 257 return DiagnosedSilenceableFailure(success()); 258 } 259 SimpleRewriter rewriter(target->getContext()); 260 FailureOr<GenericOp> res = 261 interchangeGenericOp(rewriter, target, interchangeVector); 262 if (failed(res)) 263 return DiagnosedSilenceableFailure::definiteFailure(); 264 results.push_back(res->getOperation()); 265 return DiagnosedSilenceableFailure(success()); 266 } 267 268 LogicalResult transform::InterchangeOp::verify() { 269 SmallVector<unsigned> permutation = 270 extractUIntArray(getIteratorInterchange()); 271 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 272 if (!std::is_permutation(sequence.begin(), sequence.end(), 273 permutation.begin(), permutation.end())) { 274 return emitOpError() 275 << "expects iterator_interchange to be a permutation, found " 276 << getIteratorInterchange(); 277 } 278 return success(); 279 } 280 281 //===---------------------------------------------------------------------===// 282 // MultiTileSizesOp 283 //===---------------------------------------------------------------------===// 284 285 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( 286 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) { 287 OpBuilder builder(target.getContext()); 288 builder.setInsertionPoint(target); 289 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); 290 OpFoldResult divisor = builder.getIndexAttr(getDivisor()); 291 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( 292 builder, target, getDimension(), targetSize, divisor); 293 if (failed(spec)) { 294 return emitSilenceableError() << "could not generate tile size computation"; 295 } 296 297 AffineExpr s0 = builder.getAffineSymbolExpr(0); 298 AffineExpr s1 = builder.getAffineSymbolExpr(1); 299 Operation *splitPoint = 300 makeComposedAffineApply(builder, target.getLoc(), s0 * s1, 301 {spec->lowTileSize, spec->lowTripCount}); 302 Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); 303 Operation *highTileSize = spec->highTileSize.getDefiningOp(); 304 assert(lowTileSize && highTileSize && splitPoint && 305 "tile sizes are not produced by operations"); 306 results.reserve(results.size() + 3); 307 results.push_back(lowTileSize); 308 results.push_back(highTileSize); 309 results.push_back(splitPoint); 310 return DiagnosedSilenceableFailure::success(); 311 } 312 313 void transform::MultiTileSizesOp::getEffects( 314 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 315 onlyReadsHandle(getTarget(), effects); 316 producesHandle(getResults(), effects); 317 modifiesPayload(effects); 318 } 319 320 //===---------------------------------------------------------------------===// 321 // PadOp 322 //===---------------------------------------------------------------------===// 323 324 DiagnosedSilenceableFailure 325 transform::PadOp::applyToOne(linalg::LinalgOp target, 326 SmallVectorImpl<Operation *> &results, 327 transform::TransformState &state) { 328 // Convert the integer packing flags to booleans. 329 SmallVector<bool> packPaddings; 330 for (int64_t packPadding : extractI64Array(getPackPaddings())) 331 packPaddings.push_back(static_cast<bool>(packPadding)); 332 333 // Convert the padding values to attributes. 334 SmallVector<Attribute> paddingValues; 335 for (auto const &it : 336 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 337 Attribute attr = std::get<0>(it); 338 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 339 // Try to parse string attributes to obtain an attribute of element type. 340 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 341 paddingValues.push_back( 342 parseAttribute(attr.cast<StringAttr>(), elementType)); 343 if (!paddingValues.back()) { 344 auto diag = this->emitOpError("expects a padding that parses to ") 345 << elementType << ", got " << std::get<0>(it); 346 diag.attachNote(target.getLoc()) << "when applied to this op"; 347 return DiagnosedSilenceableFailure::definiteFailure(); 348 } 349 continue; 350 } 351 // Otherwise, add the attribute directly. 352 if (attr.getType() != elementType) { 353 auto diag = this->emitOpError("expects a padding value of type ") 354 << elementType << ", got " << attr; 355 diag.attachNote(target.getLoc()) << "when applied to this op"; 356 return DiagnosedSilenceableFailure::definiteFailure(); 357 } 358 paddingValues.push_back(attr); 359 } 360 361 // Extract the transpose vectors. 362 SmallVector<SmallVector<int64_t>> transposePaddings; 363 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 364 transposePaddings.push_back( 365 extractI64Array(transposeVector.cast<ArrayAttr>())); 366 367 LinalgPaddingOptions paddingOptions; 368 paddingOptions.setPaddingValues(paddingValues); 369 paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); 370 paddingOptions.setPackPaddings(packPaddings); 371 paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); 372 paddingOptions.setTransposePaddings(transposePaddings); 373 374 FailureOr<LinalgOp> result = 375 tryApply<LinalgPaddingPattern>(target, paddingOptions); 376 if (succeeded(result)) { 377 results.push_back(result->getOperation()); 378 return DiagnosedSilenceableFailure(success()); 379 } 380 381 results.assign(1, nullptr); 382 return emitDefaultSilenceableFailure(target); 383 } 384 385 LogicalResult transform::PadOp::verify() { 386 SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings()); 387 if (any_of(packPaddings, [](int64_t packPadding) { 388 return packPadding != 0 && packPadding != 1; 389 })) { 390 return emitOpError() 391 << "expects pack_paddings to contain booleans (0/1), found " 392 << getPackPaddings(); 393 } 394 395 SmallVector<int64_t> paddingDimensions = 396 extractI64Array(getPaddingDimensions()); 397 if (any_of(paddingDimensions, 398 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 399 return emitOpError() 400 << "expects padding_dimensions to contain positive integers, found " 401 << getPaddingDimensions(); 402 } 403 404 SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings()); 405 if (any_of(hoistPaddings, 406 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 407 return emitOpError() 408 << "expects hoist_paddings to contain positive integers, found " 409 << getHoistPaddings(); 410 } 411 412 ArrayAttr transposes = getTransposePaddings(); 413 for (Attribute attr : transposes) { 414 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 415 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 416 if (!std::is_permutation(sequence.begin(), sequence.end(), 417 transpose.begin(), transpose.end())) { 418 return emitOpError() 419 << "expects transpose_paddings to be a permutation, found " 420 << attr; 421 } 422 } 423 return success(); 424 } 425 426 //===----------------------------------------------------------------------===// 427 // PromoteOp 428 //===----------------------------------------------------------------------===// 429 430 DiagnosedSilenceableFailure 431 transform::PromoteOp::applyToOne(linalg::LinalgOp target, 432 SmallVectorImpl<Operation *> &results, 433 transform::TransformState &state) { 434 LinalgPromotionOptions promotionOptions; 435 if (!getOperandsToPromote().empty()) 436 promotionOptions = promotionOptions.setOperandsToPromote( 437 extractFromI64ArrayAttr(getOperandsToPromote())); 438 if (getUseFullTilesByDefault()) 439 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( 440 getUseFullTilesByDefault()); 441 if (getUseAlloca()) 442 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); 443 if (!getUseFullTileBuffers().empty()) 444 promotionOptions = promotionOptions.setUseFullTileBuffers( 445 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); 446 if (getAlignment().hasValue()) 447 promotionOptions = promotionOptions.setAlignment(*getAlignment()); 448 449 if (failed(promoteSubviewsPrecondition(target, promotionOptions))) 450 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 451 452 SimpleRewriter rewriter(target->getContext()); 453 rewriter.setInsertionPoint(target); 454 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); 455 if (failed(res)) 456 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 457 results.push_back(target); 458 return DiagnosedSilenceableFailure(success()); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // ScalarizeOp 463 //===----------------------------------------------------------------------===// 464 465 DiagnosedSilenceableFailure 466 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, 467 SmallVectorImpl<Operation *> &results, 468 transform::TransformState &state) { 469 LinalgTilingOptions tilingOptions; 470 tilingOptions.scalarizeDynamicDims(); 471 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 472 // sizes and asserts that it is not already set. 473 SmallVector<int64_t> emptyTileSizes; 474 LinalgTilingPattern pattern(getContext(), tilingOptions); 475 SimpleRewriter rewriter(getContext()); 476 rewriter.setInsertionPoint(target); 477 FailureOr<TiledLinalgOp> result = 478 pattern.returningMatchAndRewrite(target, rewriter); 479 if (failed(result)) 480 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 481 482 results.push_back(result->op); 483 return DiagnosedSilenceableFailure(success()); 484 } 485 486 //===----------------------------------------------------------------------===// 487 // SplitOp 488 //===----------------------------------------------------------------------===// 489 490 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, 491 TransformState &state) { 492 // Collect the dynamic split points if provided. 493 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); 494 SimpleRewriter rewriter(getContext()); 495 SmallVector<OpFoldResult> splitPoints; 496 splitPoints.reserve(payload.size()); 497 if (getDynamicSplitPoint()) { 498 auto diag = DiagnosedSilenceableFailure::success(); 499 splitPoints = llvm::to_vector(llvm::map_range( 500 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { 501 if (op->getNumResults() != 1 || 502 !op->getResult(0).getType().isIndex()) { 503 diag = emitSilenceableError() 504 << "expected dynamic split point handle to point to a " 505 "single-result index-typed op"; 506 diag.attachNote(op->getLoc()) << "dynamic split point"; 507 } 508 return OpFoldResult(op->getResult(0)); 509 })); 510 if (!diag.succeeded()) 511 return diag; 512 513 if (splitPoints.size() != payload.size()) { 514 emitError() << "expected the dynamic split point handle to point to as " 515 "many operations (" 516 << splitPoints.size() << ") as the target handle (" 517 << payload.size() << ")"; 518 return DiagnosedSilenceableFailure::definiteFailure(); 519 } 520 } else { 521 splitPoints.resize(payload.size(), 522 rewriter.getIndexAttr(getStaticSplitPoint())); 523 } 524 525 // Split each target operation. 526 SmallVector<Operation *> first, second; 527 for (const auto &pair : llvm::zip(payload, splitPoints)) { 528 Operation *target = std::get<0>(pair); 529 auto linalgOp = dyn_cast<LinalgOp>(target); 530 if (!linalgOp) { 531 auto diag = emitSilenceableError() << "only applies to structured ops"; 532 diag.attachNote(target->getLoc()) << "target op"; 533 return diag; 534 } 535 536 if (getDimension() >= linalgOp.getNumLoops()) { 537 auto diag = emitSilenceableError() << "dimension " << getDimension() 538 << " does not exist in target op"; 539 diag.attachNote(target->getLoc()) << "target op"; 540 return diag; 541 } 542 543 rewriter.setInsertionPoint(linalgOp); 544 std::tie(first.emplace_back(), second.emplace_back()) = 545 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); 546 } 547 548 results.set(getFirst().cast<OpResult>(), first); 549 results.set(getSecond().cast<OpResult>(), second); 550 return DiagnosedSilenceableFailure::success(); 551 } 552 553 void SplitOp::getEffects( 554 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 555 consumesHandle(getTarget(), effects); 556 if (getDynamicSplitPoint()) 557 onlyReadsHandle(getDynamicSplitPoint(), effects); 558 producesHandle(getResults(), effects); 559 modifiesPayload(effects); 560 } 561 562 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 563 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; 564 IntegerAttr staticSplitPoint; 565 auto pdlOperationType = 566 pdl::OperationType::get(parser.getBuilder().getContext()); 567 if (parser.parseOperand(target) || 568 parser.resolveOperand(target, pdlOperationType, result.operands) || 569 parser.parseKeyword("after")) 570 return failure(); 571 572 OptionalParseResult dynamicPointParseResult = 573 parser.parseOptionalOperand(dynamicSplitPoint); 574 if (!dynamicPointParseResult.hasValue()) { 575 int64_t staticSplitPointValue; 576 if (failed(parser.parseInteger(staticSplitPointValue))) 577 return failure(); 578 579 staticSplitPoint = 580 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); 581 } else { 582 if (failed(*dynamicPointParseResult) || 583 parser.resolveOperand(dynamicSplitPoint, pdlOperationType, 584 result.operands)) { 585 return failure(); 586 } 587 588 staticSplitPoint = 589 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); 590 } 591 592 result.addAttribute( 593 SplitOp::getStaticSplitPointAttrName(result.name).getValue(), 594 staticSplitPoint); 595 if (failed(parser.parseOptionalAttrDict(result.attributes))) 596 return failure(); 597 598 result.addTypes({pdlOperationType, pdlOperationType}); 599 return success(); 600 } 601 602 void SplitOp::print(OpAsmPrinter &printer) { 603 printer << " " << getTarget() << " after "; 604 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); 605 if (staticSplitSize != ShapedType::kDynamicSize) 606 printer << staticSplitSize; 607 else 608 printer << getDynamicSplitPoint(); 609 printer << " "; 610 printer.printOptionalAttrDict(getOperation()->getAttrs(), 611 {getStaticSplitPointAttrName()}); 612 } 613 614 LogicalResult SplitOp::verify() { 615 if ((static_cast<int64_t>(getStaticSplitPoint()) != 616 ShapedType::kDynamicSize) ^ 617 (getDynamicSplitPoint() == nullptr)) { 618 return emitOpError() 619 << "expects either a dynamic or a static split point to be provided"; 620 } 621 return success(); 622 } 623 624 //===----------------------------------------------------------------------===// 625 // SplitReductionOp 626 //===----------------------------------------------------------------------===// 627 628 DiagnosedSilenceableFailure 629 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, 630 SmallVectorImpl<Operation *> &results, 631 transform::TransformState &state) { 632 ControlSplitReductionFn splitFn = [&](LinalgOp) { 633 return std::pair<int64_t, unsigned>(getSplitFactor(), 634 getInsertSplitDimension()); 635 }; 636 SimpleRewriter rewriter(getContext()); 637 rewriter.setInsertionPoint(target); 638 FailureOr<SplitReductionResult> splitResult = 639 (getUseScalingAlgorithm()) 640 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 641 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 642 if (failed(splitResult)) 643 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 644 645 results.push_back(splitResult->initOrAlloc); 646 results.push_back(splitResult->fillOp); 647 results.push_back(splitResult->splitLinalgOp); 648 results.push_back(splitResult->resultCombiningLinalgOp); 649 return DiagnosedSilenceableFailure(success()); 650 } 651 652 //===----------------------------------------------------------------------===// 653 // TileOp 654 //===----------------------------------------------------------------------===// 655 656 DiagnosedSilenceableFailure 657 transform::TileOp::apply(TransformResults &transformResults, 658 TransformState &state) { 659 LinalgTilingOptions tilingOptions; 660 SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes()); 661 662 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); 663 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; 664 dynamicSizeProducers.reserve(getDynamicSizes().size()); 665 for (Value dynamicSizeProducerHandle : getDynamicSizes()) { 666 dynamicSizeProducers.push_back( 667 state.getPayloadOps(dynamicSizeProducerHandle)); 668 669 if (dynamicSizeProducers.back().size() != targets.size()) { 670 DiagnosedSilenceableFailure diag = 671 emitSilenceableError() 672 << "expected as many dynamic size-producing operations (" 673 << dynamicSizeProducers.back().size() << ") as target ops (" 674 << targets.size() << ")"; 675 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 676 return diag; 677 } 678 679 for (Operation *op : dynamicSizeProducers.back()) { 680 if (op->getNumResults() == 1 && 681 op->getResult(0).getType().isa<IndexType>()) 682 continue; 683 DiagnosedSilenceableFailure diag = 684 emitSilenceableError() << "expected sizes to be produced by ops " 685 "with a single index-type result"; 686 diag.attachNote(op->getLoc()) << "size producer op"; 687 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 688 return diag; 689 } 690 } 691 692 SmallVector<Operation *> tiled; 693 SmallVector<SmallVector<Operation *, 4>, 4> loops; 694 loops.resize(getLoops().size()); 695 for (auto &en : llvm::enumerate(targets)) { 696 auto linalgOp = dyn_cast<LinalgOp>(en.value()); 697 if (!linalgOp) { 698 DiagnosedSilenceableFailure diag = emitSilenceableError() 699 << "only linalg ops are supported"; 700 diag.attachNote(en.value()->getLoc()) << "target op"; 701 return diag; 702 } 703 704 unsigned index = en.index(); 705 if (!tileSizes.empty()) { 706 tilingOptions.setTileSizeComputationFunction( 707 [&, index](OpBuilder &b, Operation *) { 708 SmallVector<Value, 4> sizes; 709 sizes.reserve(tileSizes.size()); 710 unsigned dynamicIdx = 0; 711 for (OpFoldResult ofr : getMixedSizes()) { 712 if (auto attr = ofr.dyn_cast<Attribute>()) { 713 sizes.push_back(b.create<arith::ConstantIndexOp>( 714 getLoc(), attr.cast<IntegerAttr>().getInt())); 715 } else { 716 sizes.push_back( 717 dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); 718 } 719 } 720 return sizes; 721 }); 722 } 723 724 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 725 LinalgTilingPattern pattern(getContext(), tilingOptions); 726 SimpleRewriter rewriter(linalgOp.getContext()); 727 FailureOr<TiledLinalgOp> tiledOp = 728 pattern.returningMatchAndRewrite(linalgOp, rewriter); 729 if (failed(tiledOp)) 730 return DiagnosedSilenceableFailure::definiteFailure(); 731 732 tiled.push_back(tiledOp->op); 733 for (const auto &en2 : llvm::enumerate(tiledOp->loops)) 734 loops[en2.index()].push_back(en2.value()); 735 } 736 737 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); 738 for (const auto &en : llvm::enumerate(loops)) 739 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); 740 741 return DiagnosedSilenceableFailure::success(); 742 } 743 744 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() { 745 ValueRange dynamic = getDynamicSizes(); 746 SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes()); 747 SmallVector<OpFoldResult> results; 748 results.reserve(tileSizes.size()); 749 unsigned dynamicPos = 0; 750 Builder builder(getContext()); 751 for (int64_t size : tileSizes) { 752 if (size == ShapedType::kDynamicSize) { 753 results.push_back(dynamic[dynamicPos++]); 754 } else { 755 results.push_back(builder.getIndexAttr(size)); 756 } 757 } 758 return results; 759 } 760 761 ParseResult transform::TileOp::parse(OpAsmParser &parser, 762 OperationState &result) { 763 OpAsmParser::UnresolvedOperand target; 764 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; 765 ArrayAttr staticSizes; 766 auto pdlOperationType = pdl::OperationType::get(parser.getContext()); 767 if (parser.parseOperand(target) || 768 parser.resolveOperand(target, pdlOperationType, result.operands) || 769 parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || 770 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || 771 parser.parseOptionalAttrDict(result.attributes)) 772 return ParseResult::failure(); 773 774 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); 775 size_t numExpectedLoops = 776 staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0); 777 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); 778 return success(); 779 } 780 781 void TileOp::print(OpAsmPrinter &p) { 782 p << ' ' << getTarget(); 783 printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), 784 getStaticSizes()); 785 p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); 786 } 787 788 void transform::TileOp::getEffects( 789 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 790 consumesHandle(getTarget(), effects); 791 onlyReadsHandle(getDynamicSizes(), effects); 792 producesHandle(getTiledLinalgOp(), effects); 793 producesHandle(getLoops(), effects); 794 modifiesPayload(effects); 795 } 796 797 //===----------------------------------------------------------------------===// 798 // VectorizeOp 799 //===----------------------------------------------------------------------===// 800 801 DiagnosedSilenceableFailure 802 transform::VectorizeOp::applyToOne(Operation *target, 803 SmallVectorImpl<Operation *> &results, 804 transform::TransformState &state) { 805 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 806 auto diag = this->emitOpError("requires isolated-from-above targets"); 807 diag.attachNote(target->getLoc()) << "non-isolated target"; 808 return DiagnosedSilenceableFailure::definiteFailure(); 809 } 810 811 MLIRContext *ctx = getContext(); 812 RewritePatternSet patterns(ctx); 813 patterns.add<LinalgVectorizationPattern>(ctx); 814 815 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 816 vector::populateVectorReductionToContractPatterns(patterns); 817 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 818 linalg::LinalgCopyVTWForwardingPattern>(ctx, 819 /*benefit=*/2); 820 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 821 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 822 if (getVectorizePadding()) 823 linalg::populatePadOpVectorizationPatterns(patterns); 824 825 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 826 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 827 828 results.push_back(target); 829 return DiagnosedSilenceableFailure(success()); 830 } 831 832 //===----------------------------------------------------------------------===// 833 // Transform op registration 834 //===----------------------------------------------------------------------===// 835 836 namespace { 837 /// Registers new ops and declares PDL as dependent dialect since the additional 838 /// ops are using PDL types for operands and results. 839 class LinalgTransformDialectExtension 840 : public transform::TransformDialectExtension< 841 LinalgTransformDialectExtension> { 842 public: 843 LinalgTransformDialectExtension() { 844 declareDependentDialect<AffineDialect>(); 845 declareDependentDialect<arith::ArithmeticDialect>(); 846 declareDependentDialect<pdl::PDLDialect>(); 847 declareDependentDialect<scf::SCFDialect>(); 848 declareDependentDialect<vector::VectorDialect>(); 849 registerTransformOps< 850 #define GET_OP_LIST 851 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 852 >(); 853 } 854 }; 855 } // namespace 856 857 #define GET_OP_CLASSES 858 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 859 860 void mlir::linalg::registerTransformDialectExtension( 861 DialectRegistry ®istry) { 862 registry.addExtensions<LinalgTransformDialectExtension>(); 863 } 864