1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Linalg/IR/Linalg.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/Interfaces/TilingInterface.h" 20 #include "mlir/Parser/Parser.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 #include "llvm/ADT/StringSet.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 using namespace mlir::transform; 27 28 /// Extracts a vector of unsigned from an array attribute. Asserts if the 29 /// attribute contains values other than intergers. May truncate. 30 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 31 SmallVector<unsigned> result; 32 result.reserve(attr.size()); 33 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 34 result.push_back(value.getZExtValue()); 35 return result; 36 } 37 38 namespace { 39 /// A simple pattern rewriter that implements no special logic. 40 class SimpleRewriter : public PatternRewriter { 41 public: 42 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 43 }; 44 } // namespace 45 46 /// Attempts to apply the pattern specified as template argument to the given 47 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 48 /// function that returns the "main" result or failure. Returns failure if the 49 /// pattern failed to apply. Extra arguments are forwarded to the pattern 50 /// constructor. 51 template <typename PatternTy, typename... Args> 52 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 53 // Check if the given operation has the type expected by the pattern. 54 using OpTy = typename llvm::function_traits< 55 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 56 auto op = dyn_cast<OpTy>(operation); 57 if (!op) 58 return failure(); 59 60 // Apply the pattern directly to the op. 61 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 62 SimpleRewriter rewriter(operation->getContext()); 63 rewriter.setInsertionPoint(operation); 64 auto result = pattern.returningMatchAndRewrite(op, rewriter); 65 if (failed(result)) 66 return failure(); 67 return cast<LinalgOp>(result->getOperation()); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // DecomposeOp 72 //===----------------------------------------------------------------------===// 73 74 DiagnosedSilenceableFailure 75 transform::DecomposeOp::applyToOne(linalg::LinalgOp target, 76 SmallVectorImpl<Operation *> &results, 77 transform::TransformState &state) { 78 FailureOr<LinalgOp> windowed = 79 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 80 if (succeeded(windowed)) { 81 results.push_back(*windowed); 82 return DiagnosedSilenceableFailure(success()); 83 } 84 FailureOr<LinalgOp> depthwise = 85 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 86 if (succeeded(depthwise)) { 87 results.push_back(*depthwise); 88 return DiagnosedSilenceableFailure(success()); 89 } 90 results.assign(1, nullptr); 91 return emitDefaultSilenceableFailure(target); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // FuseOp 96 //===----------------------------------------------------------------------===// 97 98 /// Apply a tiling transformation to all payload ops and store both the 99 /// tiled operation as well as the created tile loops. 100 static LogicalResult 101 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps, 102 unsigned numLoops, 103 transform::TransformResults &transformResults, 104 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 105 SmallVector<Operation *> tiledLinalgOps; 106 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 107 for (unsigned int i = 0; i < numLoops; ++i) 108 loopOps[i].reserve(payloadOps.size()); 109 110 for (Operation *target : payloadOps) { 111 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 112 if (!linalgOp) 113 return transformOp->emitError("only LinalgOps are supported"); 114 115 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 116 if (failed(tiled)) 117 return failure(); 118 119 tiledLinalgOps.push_back(tiled->op); 120 if (tiled->loops.size() != numLoops) 121 // Not enough loops were generated. This usually means that the input size 122 // was smaller than the tiling size. 123 // TODO: LinalgTilingPattern should return failure(). 124 return failure(); 125 for (unsigned int i = 0; i < numLoops; ++i) 126 loopOps[i].push_back(tiled->loops[i]); 127 } 128 129 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 130 for (unsigned int i = 0; i < numLoops; ++i) 131 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 132 return success(); 133 } 134 135 /// Parse a tiling-like operation that returns the tiled op as well as the 136 /// created tile loops. The function counts the non-zero tile sizes to compute 137 /// the number of results. 138 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 139 StringRef sizesAttrName) { 140 OpAsmParser::UnresolvedOperand targetOperand; 141 SMLoc opLoc = parser.getCurrentLocation(); 142 if (parser.parseOperand(targetOperand) || 143 parser.parseOptionalAttrDict(result.attributes)) 144 return failure(); 145 Attribute sizesAttr = result.attributes.get(sizesAttrName); 146 if (!sizesAttr) 147 return parser.emitError(opLoc) 148 << "expected '" << sizesAttrName << "' attribute"; 149 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 150 if (!sizesArrayAttr) 151 return parser.emitError(opLoc) 152 << "'" << sizesAttrName << "' attribute must be an array"; 153 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 154 size_t numExpectedLoops = 155 sizesArrayAttr.size() - 156 llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); 157 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 158 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 159 return failure(); 160 return success(); 161 } 162 163 DiagnosedSilenceableFailure 164 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 165 mlir::transform::TransformState &state) { 166 LinalgTilingAndFusionOptions fusionOptions; 167 fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); 168 fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); 169 170 LogicalResult result = applyTilingToAll( 171 getOperation(), state.getPayloadOps(getTarget()), 172 fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), 173 transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 174 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 175 SimpleRewriter rewriter(getContext()); 176 rewriter.setInsertionPoint(linalgOp); 177 FailureOr<TileLoopNest> tileLoopNest = 178 pattern.returningMatchAndRewrite(linalgOp, rewriter); 179 if (failed(tileLoopNest)) 180 return failure(); 181 182 TiledLinalgOp tiledLinalgOp; 183 tiledLinalgOp.op = tileLoopNest->getRootOp(); 184 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 185 tileLoopNest->getLoopOps().end()}; 186 return tiledLinalgOp; 187 }); 188 return DiagnosedSilenceableFailure(result); 189 } 190 191 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 192 OperationState &result) { 193 return parseTileLikeOp( 194 parser, result, 195 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 196 } 197 198 void transform::FuseOp::print(OpAsmPrinter &p) { 199 p << ' '; 200 p << getTarget(); 201 p.printOptionalAttrDict((*this)->getAttrs()); 202 } 203 204 LogicalResult transform::FuseOp::verify() { 205 SmallVector<int64_t> permutation = 206 extractFromI64ArrayAttr(getTileInterchange()); 207 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 208 if (!std::is_permutation(sequence.begin(), sequence.end(), 209 permutation.begin(), permutation.end())) { 210 return emitOpError() << "expects interchange to be a permutation, found " 211 << getTileInterchange(); 212 } 213 return success(); 214 } 215 216 //===----------------------------------------------------------------------===// 217 // GeneralizeOp 218 //===----------------------------------------------------------------------===// 219 220 DiagnosedSilenceableFailure 221 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, 222 SmallVectorImpl<Operation *> &results, 223 transform::TransformState &state) { 224 // Exit early if no transformation is needed. 225 if (isa<GenericOp>(target)) { 226 results.push_back(target); 227 return DiagnosedSilenceableFailure(success()); 228 } 229 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 230 if (succeeded(generic)) { 231 results.push_back(generic->getOperation()); 232 return DiagnosedSilenceableFailure(success()); 233 } 234 results.assign(1, nullptr); 235 return emitDefaultSilenceableFailure(target); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // InterchangeOp 240 //===----------------------------------------------------------------------===// 241 242 DiagnosedSilenceableFailure 243 transform::InterchangeOp::applyToOne(linalg::GenericOp target, 244 SmallVectorImpl<Operation *> &results, 245 transform::TransformState &state) { 246 SmallVector<unsigned> interchangeVector = 247 extractUIntArray(getIteratorInterchange()); 248 // Exit early if no transformation is needed. 249 if (interchangeVector.empty()) { 250 results.push_back(target); 251 return DiagnosedSilenceableFailure(success()); 252 } 253 SimpleRewriter rewriter(target->getContext()); 254 FailureOr<GenericOp> res = 255 interchangeGenericOp(rewriter, target, interchangeVector); 256 if (failed(res)) 257 return DiagnosedSilenceableFailure::definiteFailure(); 258 results.push_back(res->getOperation()); 259 return DiagnosedSilenceableFailure(success()); 260 } 261 262 LogicalResult transform::InterchangeOp::verify() { 263 SmallVector<unsigned> permutation = 264 extractUIntArray(getIteratorInterchange()); 265 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 266 if (!std::is_permutation(sequence.begin(), sequence.end(), 267 permutation.begin(), permutation.end())) { 268 return emitOpError() 269 << "expects iterator_interchange to be a permutation, found " 270 << getIteratorInterchange(); 271 } 272 return success(); 273 } 274 275 //===---------------------------------------------------------------------===// 276 // MatchOp 277 //===---------------------------------------------------------------------===// 278 279 LogicalResult transform::MatchOp::verify() { 280 bool opXorIface = getOps().hasValue() ^ getInterface().hasValue(); 281 if (!opXorIface) 282 return this->emitOpError( 283 "requires a either a match_op or a match_interface attribute (but not " 284 "both)"); 285 return success(); 286 } 287 288 DiagnosedSilenceableFailure 289 transform::MatchOp::apply(transform::TransformResults &results, 290 transform::TransformState &state) { 291 llvm::StringSet<> strs; 292 if (getOps().hasValue()) 293 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(), 294 getOps()->getAsValueRange<StringAttr>().end()); 295 296 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 297 if (payloadOps.size() != 1) 298 return DiagnosedSilenceableFailure( 299 this->emitOpError("requires exactly one target handle")); 300 301 SmallVector<Operation *> res; 302 303 auto matchFun = [&](Operation *op) { 304 if (strs.contains(op->getName().getStringRef())) 305 res.push_back(op); 306 // Interfaces cannot be matched by name, just by ID. 307 // So we specifically encode the interfaces we care about for this op. 308 if (getInterface().hasValue()) { 309 auto iface = getInterface().getValue(); 310 if (iface == transform::MatchInterfaceEnum::LinalgOp && 311 isa<linalg::LinalgOp>(op)) 312 res.push_back(op); 313 if (iface == transform::MatchInterfaceEnum::TilingInterface && 314 isa<TilingInterface>(op)) 315 res.push_back(op); 316 } 317 }; 318 319 payloadOps.front()->walk(matchFun); 320 results.set(getResult().cast<OpResult>(), res); 321 return DiagnosedSilenceableFailure(success()); 322 } 323 324 ParseResult transform::MatchOp::parse(OpAsmParser &parser, 325 OperationState &result) { 326 // Parse 'match_op' or 'interface' clause. 327 if (succeeded(parser.parseOptionalKeyword("ops"))) { 328 ArrayAttr opsAttr; 329 if (parser.parseLBrace() || 330 parser.parseCustomAttributeWithFallback( 331 opsAttr, parser.getBuilder().getType<NoneType>(), "ops", 332 result.attributes) || 333 parser.parseRBrace()) 334 return failure(); 335 } else if (succeeded(parser.parseOptionalKeyword("interface"))) { 336 if (parser.parseLBrace()) 337 return failure(); 338 StringRef attrStr; 339 auto loc = parser.getCurrentLocation(); 340 if (parser.parseKeyword(&attrStr)) 341 return failure(); 342 auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr); 343 if (!interfaceEnum) 344 return parser.emitError(loc, "invalid ") 345 << "match_interface attribute specification: \"" << attrStr << '"'; 346 transform::MatchInterfaceEnumAttr match_interfaceAttr = 347 transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(), 348 interfaceEnum.value()); 349 result.addAttribute("interface", match_interfaceAttr); 350 if (parser.parseRBrace()) 351 return failure(); 352 } else { 353 auto loc = parser.getCurrentLocation(); 354 return parser.emitError(loc, "expected ops or interface"); 355 } 356 357 OpAsmParser::UnresolvedOperand targetRawOperands[1]; 358 ArrayRef<OpAsmParser::UnresolvedOperand> targetOperands(targetRawOperands); 359 if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) || 360 parser.parseOptionalAttrDict(result.attributes)) 361 return failure(); 362 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 363 result.addTypes(pdlOpType); 364 if (parser.resolveOperands(targetOperands, pdlOpType, result.operands)) 365 return failure(); 366 return success(); 367 } 368 369 void transform::MatchOp::print(OpAsmPrinter &p) { 370 if ((*this)->getAttr("ops")) { 371 p << " ops{"; 372 p.printAttributeWithoutType(getOpsAttr()); 373 p << "}"; 374 } 375 if ((*this)->getAttr("interface")) { 376 p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}"; 377 } 378 p << " in " << getTarget(); 379 p.printOptionalAttrDict((*this)->getAttrs(), 380 /*elidedAttrs=*/{"ops", "interface"}); 381 } 382 383 //===---------------------------------------------------------------------===// 384 // MultiTileSizesOp 385 //===---------------------------------------------------------------------===// 386 387 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( 388 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) { 389 OpBuilder builder(target.getContext()); 390 builder.setInsertionPoint(target); 391 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); 392 OpFoldResult divisor = builder.getIndexAttr(getDivisor()); 393 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( 394 builder, target, getDimension(), targetSize, divisor); 395 if (failed(spec)) { 396 return emitSilenceableError() << "could not generate tile size computation"; 397 } 398 399 AffineExpr s0 = builder.getAffineSymbolExpr(0); 400 AffineExpr s1 = builder.getAffineSymbolExpr(1); 401 Operation *splitPoint = 402 makeComposedAffineApply(builder, target.getLoc(), s0 * s1, 403 {spec->lowTileSize, spec->lowTripCount}); 404 Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); 405 Operation *highTileSize = spec->highTileSize.getDefiningOp(); 406 assert(lowTileSize && highTileSize && splitPoint && 407 "tile sizes are not produced by operations"); 408 results.reserve(results.size() + 3); 409 results.push_back(lowTileSize); 410 results.push_back(highTileSize); 411 results.push_back(splitPoint); 412 return DiagnosedSilenceableFailure::success(); 413 } 414 415 void transform::MultiTileSizesOp::getEffects( 416 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 417 onlyReadsHandle(getTarget(), effects); 418 producesHandle(getResults(), effects); 419 modifiesPayload(effects); 420 } 421 422 //===---------------------------------------------------------------------===// 423 // PadOp 424 //===---------------------------------------------------------------------===// 425 426 DiagnosedSilenceableFailure 427 transform::PadOp::applyToOne(linalg::LinalgOp target, 428 SmallVectorImpl<Operation *> &results, 429 transform::TransformState &state) { 430 // Convert the integer packing flags to booleans. 431 SmallVector<bool> packPaddings; 432 for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) 433 packPaddings.push_back(static_cast<bool>(packPadding)); 434 435 // Convert the padding values to attributes. 436 SmallVector<Attribute> paddingValues; 437 for (auto const &it : 438 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 439 Attribute attr = std::get<0>(it); 440 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 441 // Try to parse string attributes to obtain an attribute of element type. 442 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 443 paddingValues.push_back( 444 parseAttribute(attr.cast<StringAttr>(), elementType)); 445 if (!paddingValues.back()) { 446 auto diag = this->emitOpError("expects a padding that parses to ") 447 << elementType << ", got " << std::get<0>(it); 448 diag.attachNote(target.getLoc()) << "when applied to this op"; 449 return DiagnosedSilenceableFailure::definiteFailure(); 450 } 451 continue; 452 } 453 // Otherwise, add the attribute directly. 454 if (attr.getType() != elementType) { 455 auto diag = this->emitOpError("expects a padding value of type ") 456 << elementType << ", got " << attr; 457 diag.attachNote(target.getLoc()) << "when applied to this op"; 458 return DiagnosedSilenceableFailure::definiteFailure(); 459 } 460 paddingValues.push_back(attr); 461 } 462 463 // Extract the transpose vectors. 464 SmallVector<SmallVector<int64_t>> transposePaddings; 465 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 466 transposePaddings.push_back( 467 extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>())); 468 469 LinalgPaddingOptions paddingOptions; 470 paddingOptions.setPaddingValues(paddingValues); 471 paddingOptions.setPaddingDimensions( 472 extractFromI64ArrayAttr(getPaddingDimensions())); 473 paddingOptions.setPackPaddings(packPaddings); 474 paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); 475 paddingOptions.setTransposePaddings(transposePaddings); 476 477 FailureOr<LinalgOp> result = 478 tryApply<LinalgPaddingPattern>(target, paddingOptions); 479 if (succeeded(result)) { 480 results.push_back(result->getOperation()); 481 return DiagnosedSilenceableFailure(success()); 482 } 483 484 results.assign(1, nullptr); 485 return emitDefaultSilenceableFailure(target); 486 } 487 488 LogicalResult transform::PadOp::verify() { 489 SmallVector<int64_t> packPaddings = 490 extractFromI64ArrayAttr(getPackPaddings()); 491 if (any_of(packPaddings, [](int64_t packPadding) { 492 return packPadding != 0 && packPadding != 1; 493 })) { 494 return emitOpError() 495 << "expects pack_paddings to contain booleans (0/1), found " 496 << getPackPaddings(); 497 } 498 499 SmallVector<int64_t> paddingDimensions = 500 extractFromI64ArrayAttr(getPaddingDimensions()); 501 if (any_of(paddingDimensions, 502 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 503 return emitOpError() 504 << "expects padding_dimensions to contain positive integers, found " 505 << getPaddingDimensions(); 506 } 507 508 SmallVector<int64_t> hoistPaddings = 509 extractFromI64ArrayAttr(getHoistPaddings()); 510 if (any_of(hoistPaddings, 511 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 512 return emitOpError() 513 << "expects hoist_paddings to contain positive integers, found " 514 << getHoistPaddings(); 515 } 516 517 ArrayAttr transposes = getTransposePaddings(); 518 for (Attribute attr : transposes) { 519 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 520 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 521 if (!std::is_permutation(sequence.begin(), sequence.end(), 522 transpose.begin(), transpose.end())) { 523 return emitOpError() 524 << "expects transpose_paddings to be a permutation, found " 525 << attr; 526 } 527 } 528 return success(); 529 } 530 531 //===----------------------------------------------------------------------===// 532 // PromoteOp 533 //===----------------------------------------------------------------------===// 534 535 DiagnosedSilenceableFailure 536 transform::PromoteOp::applyToOne(linalg::LinalgOp target, 537 SmallVectorImpl<Operation *> &results, 538 transform::TransformState &state) { 539 LinalgPromotionOptions promotionOptions; 540 if (!getOperandsToPromote().empty()) 541 promotionOptions = promotionOptions.setOperandsToPromote( 542 extractFromI64ArrayAttr(getOperandsToPromote())); 543 if (getUseFullTilesByDefault()) 544 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( 545 getUseFullTilesByDefault()); 546 if (getUseAlloca()) 547 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); 548 if (!getUseFullTileBuffers().empty()) 549 promotionOptions = promotionOptions.setUseFullTileBuffers( 550 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); 551 if (getAlignment().has_value()) 552 promotionOptions = promotionOptions.setAlignment(*getAlignment()); 553 554 if (failed(promoteSubviewsPrecondition(target, promotionOptions))) 555 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 556 557 SimpleRewriter rewriter(target->getContext()); 558 rewriter.setInsertionPoint(target); 559 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); 560 if (failed(res)) 561 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 562 results.push_back(target); 563 return DiagnosedSilenceableFailure(success()); 564 } 565 566 //===----------------------------------------------------------------------===// 567 // ScalarizeOp 568 //===----------------------------------------------------------------------===// 569 570 DiagnosedSilenceableFailure 571 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, 572 SmallVectorImpl<Operation *> &results, 573 transform::TransformState &state) { 574 LinalgTilingOptions tilingOptions; 575 tilingOptions.scalarizeDynamicDims(); 576 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 577 // sizes and asserts that it is not already set. 578 SmallVector<int64_t> emptyTileSizes; 579 LinalgTilingPattern pattern(getContext(), tilingOptions); 580 SimpleRewriter rewriter(getContext()); 581 rewriter.setInsertionPoint(target); 582 FailureOr<TiledLinalgOp> result = 583 pattern.returningMatchAndRewrite(target, rewriter); 584 if (failed(result)) 585 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 586 587 results.push_back(result->op); 588 return DiagnosedSilenceableFailure(success()); 589 } 590 591 //===----------------------------------------------------------------------===// 592 // SplitOp 593 //===----------------------------------------------------------------------===// 594 595 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, 596 TransformState &state) { 597 // Collect the dynamic split points if provided. 598 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); 599 SimpleRewriter rewriter(getContext()); 600 SmallVector<OpFoldResult> splitPoints; 601 splitPoints.reserve(payload.size()); 602 if (getDynamicSplitPoint()) { 603 auto diag = DiagnosedSilenceableFailure::success(); 604 splitPoints = llvm::to_vector(llvm::map_range( 605 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { 606 if (op->getNumResults() != 1 || 607 !op->getResult(0).getType().isIndex()) { 608 diag = emitSilenceableError() 609 << "expected dynamic split point handle to point to a " 610 "single-result index-typed op"; 611 diag.attachNote(op->getLoc()) << "dynamic split point"; 612 } 613 return OpFoldResult(op->getResult(0)); 614 })); 615 if (!diag.succeeded()) 616 return diag; 617 618 if (splitPoints.size() != payload.size()) { 619 emitError() << "expected the dynamic split point handle to point to as " 620 "many operations (" 621 << splitPoints.size() << ") as the target handle (" 622 << payload.size() << ")"; 623 return DiagnosedSilenceableFailure::definiteFailure(); 624 } 625 } else { 626 splitPoints.resize(payload.size(), 627 rewriter.getIndexAttr(getStaticSplitPoint())); 628 } 629 630 // Split each target operation. 631 SmallVector<Operation *> first, second; 632 for (const auto &pair : llvm::zip(payload, splitPoints)) { 633 Operation *target = std::get<0>(pair); 634 auto linalgOp = dyn_cast<LinalgOp>(target); 635 if (!linalgOp) { 636 auto diag = emitSilenceableError() << "only applies to structured ops"; 637 diag.attachNote(target->getLoc()) << "target op"; 638 return diag; 639 } 640 641 if (getDimension() >= linalgOp.getNumLoops()) { 642 auto diag = emitSilenceableError() << "dimension " << getDimension() 643 << " does not exist in target op"; 644 diag.attachNote(target->getLoc()) << "target op"; 645 return diag; 646 } 647 648 rewriter.setInsertionPoint(linalgOp); 649 std::tie(first.emplace_back(), second.emplace_back()) = 650 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); 651 } 652 653 results.set(getFirst().cast<OpResult>(), first); 654 results.set(getSecond().cast<OpResult>(), second); 655 return DiagnosedSilenceableFailure::success(); 656 } 657 658 void SplitOp::getEffects( 659 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 660 consumesHandle(getTarget(), effects); 661 if (getDynamicSplitPoint()) 662 onlyReadsHandle(getDynamicSplitPoint(), effects); 663 producesHandle(getResults(), effects); 664 modifiesPayload(effects); 665 } 666 667 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 668 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; 669 IntegerAttr staticSplitPoint; 670 auto pdlOperationType = 671 pdl::OperationType::get(parser.getBuilder().getContext()); 672 if (parser.parseOperand(target) || 673 parser.resolveOperand(target, pdlOperationType, result.operands) || 674 parser.parseKeyword("after")) 675 return failure(); 676 677 OptionalParseResult dynamicPointParseResult = 678 parser.parseOptionalOperand(dynamicSplitPoint); 679 if (!dynamicPointParseResult.hasValue()) { 680 int64_t staticSplitPointValue; 681 if (failed(parser.parseInteger(staticSplitPointValue))) 682 return failure(); 683 684 staticSplitPoint = 685 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); 686 } else { 687 if (failed(*dynamicPointParseResult) || 688 parser.resolveOperand(dynamicSplitPoint, pdlOperationType, 689 result.operands)) { 690 return failure(); 691 } 692 693 staticSplitPoint = 694 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); 695 } 696 697 result.addAttribute( 698 SplitOp::getStaticSplitPointAttrName(result.name).getValue(), 699 staticSplitPoint); 700 if (failed(parser.parseOptionalAttrDict(result.attributes))) 701 return failure(); 702 703 result.addTypes({pdlOperationType, pdlOperationType}); 704 return success(); 705 } 706 707 void SplitOp::print(OpAsmPrinter &printer) { 708 printer << " " << getTarget() << " after "; 709 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); 710 if (staticSplitSize != ShapedType::kDynamicSize) 711 printer << staticSplitSize; 712 else 713 printer << getDynamicSplitPoint(); 714 printer << " "; 715 printer.printOptionalAttrDict(getOperation()->getAttrs(), 716 {getStaticSplitPointAttrName()}); 717 } 718 719 LogicalResult SplitOp::verify() { 720 if ((static_cast<int64_t>(getStaticSplitPoint()) != 721 ShapedType::kDynamicSize) ^ 722 (getDynamicSplitPoint() == nullptr)) { 723 return emitOpError() 724 << "expects either a dynamic or a static split point to be provided"; 725 } 726 return success(); 727 } 728 729 //===----------------------------------------------------------------------===// 730 // SplitReductionOp 731 //===----------------------------------------------------------------------===// 732 733 DiagnosedSilenceableFailure 734 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, 735 SmallVectorImpl<Operation *> &results, 736 transform::TransformState &state) { 737 ControlSplitReductionFn splitFn = [&](LinalgOp) { 738 return std::pair<int64_t, unsigned>(getSplitFactor(), 739 getInsertSplitDimension()); 740 }; 741 SimpleRewriter rewriter(getContext()); 742 rewriter.setInsertionPoint(target); 743 FailureOr<SplitReductionResult> splitResult = 744 (getUseScalingAlgorithm()) 745 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 746 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 747 if (failed(splitResult)) 748 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 749 750 results.push_back(splitResult->initOrAlloc); 751 results.push_back(splitResult->fillOp); 752 results.push_back(splitResult->splitLinalgOp); 753 results.push_back(splitResult->resultCombiningLinalgOp); 754 return DiagnosedSilenceableFailure(success()); 755 } 756 757 //===----------------------------------------------------------------------===// 758 // TileOp 759 //===----------------------------------------------------------------------===// 760 761 DiagnosedSilenceableFailure 762 transform::TileOp::apply(TransformResults &transformResults, 763 TransformState &state) { 764 LinalgTilingOptions tilingOptions; 765 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 766 767 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); 768 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; 769 dynamicSizeProducers.reserve(getDynamicSizes().size()); 770 for (Value dynamicSizeProducerHandle : getDynamicSizes()) { 771 dynamicSizeProducers.push_back( 772 state.getPayloadOps(dynamicSizeProducerHandle)); 773 774 if (dynamicSizeProducers.back().size() != targets.size()) { 775 DiagnosedSilenceableFailure diag = 776 emitSilenceableError() 777 << "expected as many dynamic size-producing operations (" 778 << dynamicSizeProducers.back().size() << ") as target ops (" 779 << targets.size() << ")"; 780 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 781 return diag; 782 } 783 784 for (Operation *op : dynamicSizeProducers.back()) { 785 if (op->getNumResults() == 1 && 786 op->getResult(0).getType().isa<IndexType>()) 787 continue; 788 DiagnosedSilenceableFailure diag = 789 emitSilenceableError() << "expected sizes to be produced by ops " 790 "with a single index-type result"; 791 diag.attachNote(op->getLoc()) << "size producer op"; 792 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 793 return diag; 794 } 795 } 796 797 SmallVector<Operation *> tiled; 798 SmallVector<SmallVector<Operation *, 4>, 4> loops; 799 loops.resize(getLoops().size()); 800 for (auto &en : llvm::enumerate(targets)) { 801 auto linalgOp = dyn_cast<LinalgOp>(en.value()); 802 if (!linalgOp) { 803 DiagnosedSilenceableFailure diag = emitSilenceableError() 804 << "only linalg ops are supported"; 805 diag.attachNote(en.value()->getLoc()) << "target op"; 806 return diag; 807 } 808 809 unsigned index = en.index(); 810 if (!tileSizes.empty()) { 811 tilingOptions.setTileSizeComputationFunction( 812 [&, index](OpBuilder &b, Operation *) { 813 SmallVector<Value, 4> sizes; 814 sizes.reserve(tileSizes.size()); 815 unsigned dynamicIdx = 0; 816 for (OpFoldResult ofr : getMixedSizes()) { 817 if (auto attr = ofr.dyn_cast<Attribute>()) { 818 sizes.push_back(b.create<arith::ConstantIndexOp>( 819 getLoc(), attr.cast<IntegerAttr>().getInt())); 820 } else { 821 sizes.push_back( 822 dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); 823 } 824 } 825 return sizes; 826 }); 827 } 828 829 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 830 LinalgTilingPattern pattern(getContext(), tilingOptions); 831 SimpleRewriter rewriter(linalgOp.getContext()); 832 FailureOr<TiledLinalgOp> tiledOp = 833 pattern.returningMatchAndRewrite(linalgOp, rewriter); 834 if (failed(tiledOp)) 835 return DiagnosedSilenceableFailure::definiteFailure(); 836 837 tiled.push_back(tiledOp->op); 838 for (const auto &en2 : llvm::enumerate(tiledOp->loops)) 839 loops[en2.index()].push_back(en2.value()); 840 } 841 842 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); 843 for (const auto &en : llvm::enumerate(loops)) 844 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); 845 846 return DiagnosedSilenceableFailure::success(); 847 } 848 849 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() { 850 ValueRange dynamic = getDynamicSizes(); 851 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 852 SmallVector<OpFoldResult> results; 853 results.reserve(tileSizes.size()); 854 unsigned dynamicPos = 0; 855 Builder builder(getContext()); 856 for (int64_t size : tileSizes) { 857 if (size == ShapedType::kDynamicSize) { 858 results.push_back(dynamic[dynamicPos++]); 859 } else { 860 results.push_back(builder.getIndexAttr(size)); 861 } 862 } 863 return results; 864 } 865 866 ParseResult transform::TileOp::parse(OpAsmParser &parser, 867 OperationState &result) { 868 OpAsmParser::UnresolvedOperand target; 869 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; 870 ArrayAttr staticSizes; 871 auto pdlOperationType = pdl::OperationType::get(parser.getContext()); 872 if (parser.parseOperand(target) || 873 parser.resolveOperand(target, pdlOperationType, result.operands) || 874 parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || 875 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || 876 parser.parseOptionalAttrDict(result.attributes)) 877 return ParseResult::failure(); 878 879 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); 880 size_t numExpectedLoops = 881 staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); 882 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); 883 return success(); 884 } 885 886 void TileOp::print(OpAsmPrinter &p) { 887 p << ' ' << getTarget(); 888 printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), 889 getStaticSizes()); 890 p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); 891 } 892 893 void transform::TileOp::getEffects( 894 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 895 consumesHandle(getTarget(), effects); 896 onlyReadsHandle(getDynamicSizes(), effects); 897 producesHandle(getTiledLinalgOp(), effects); 898 producesHandle(getLoops(), effects); 899 modifiesPayload(effects); 900 } 901 902 //===----------------------------------------------------------------------===// 903 // TileToForeachThreadOp 904 //===----------------------------------------------------------------------===// 905 906 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( 907 TilingInterface target, SmallVectorImpl<Operation *> &results, 908 transform::TransformState &state) { 909 IRRewriter rewriter(getContext()); 910 rewriter.setInsertionPoint(target); 911 auto maybeThreadDimMappingAttr = getThreadDimMapping(); 912 auto dimMapping = 913 llvm::to_vector(maybeThreadDimMappingAttr 914 ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) 915 : ArrayRef<int64_t>{}); 916 917 FailureOr<ForeachThreadTilingResult> tilingResult = failure(); 918 if (Optional<ArrayAttr> numThreads = getNumThreads()) 919 tilingResult = linalg::tileToForeachThreadOp( 920 rewriter, target, getAsOpFoldResult(*numThreads), dimMapping); 921 922 if (Optional<ArrayAttr> tileSizes = getTileSizes()) 923 tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( 924 rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping); 925 926 if (failed(tilingResult)) 927 return emitDefaultSilenceableFailure(target); 928 rewriter.replaceOp(target, tilingResult->tileOp->getResults()); 929 results.assign({tilingResult->tileOp, tilingResult->tiledOp}); 930 return DiagnosedSilenceableFailure(success()); 931 } 932 933 //===----------------------------------------------------------------------===// 934 // VectorizeOp 935 //===----------------------------------------------------------------------===// 936 937 DiagnosedSilenceableFailure 938 transform::VectorizeOp::applyToOne(Operation *target, 939 SmallVectorImpl<Operation *> &results, 940 transform::TransformState &state) { 941 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 942 auto diag = this->emitOpError("requires isolated-from-above targets"); 943 diag.attachNote(target->getLoc()) << "non-isolated target"; 944 return DiagnosedSilenceableFailure::definiteFailure(); 945 } 946 947 MLIRContext *ctx = getContext(); 948 RewritePatternSet patterns(ctx); 949 patterns.add<LinalgVectorizationPattern>(ctx); 950 951 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 952 vector::populateVectorReductionToContractPatterns(patterns); 953 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 954 linalg::LinalgCopyVTWForwardingPattern>(ctx, 955 /*benefit=*/2); 956 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 957 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 958 if (getVectorizePadding()) 959 linalg::populatePadOpVectorizationPatterns(patterns); 960 961 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 962 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 963 964 results.push_back(target); 965 return DiagnosedSilenceableFailure(success()); 966 } 967 968 //===----------------------------------------------------------------------===// 969 // Transform op registration 970 //===----------------------------------------------------------------------===// 971 972 namespace { 973 /// Registers new ops and declares PDL as dependent dialect since the additional 974 /// ops are using PDL types for operands and results. 975 class LinalgTransformDialectExtension 976 : public transform::TransformDialectExtension< 977 LinalgTransformDialectExtension> { 978 public: 979 LinalgTransformDialectExtension() { 980 declareDependentDialect<AffineDialect>(); 981 declareDependentDialect<arith::ArithmeticDialect>(); 982 declareDependentDialect<pdl::PDLDialect>(); 983 declareDependentDialect<scf::SCFDialect>(); 984 declareDependentDialect<vector::VectorDialect>(); 985 registerTransformOps< 986 #define GET_OP_LIST 987 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 988 >(); 989 } 990 }; 991 } // namespace 992 993 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" 994 995 #define GET_OP_CLASSES 996 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 997 998 void mlir::linalg::registerTransformDialectExtension( 999 DialectRegistry ®istry) { 1000 registry.addExtensions<LinalgTransformDialectExtension>(); 1001 } 1002