1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Linalg/IR/Linalg.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/Interfaces/TilingInterface.h" 20 #include "mlir/Parser/Parser.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 #include "llvm/ADT/StringSet.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 using namespace mlir::transform; 27 28 /// Extracts a vector of unsigned from an array attribute. Asserts if the 29 /// attribute contains values other than intergers. May truncate. 30 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 31 SmallVector<unsigned> result; 32 result.reserve(attr.size()); 33 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 34 result.push_back(value.getZExtValue()); 35 return result; 36 } 37 38 namespace { 39 /// A simple pattern rewriter that implements no special logic. 40 class SimpleRewriter : public PatternRewriter { 41 public: 42 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 43 }; 44 } // namespace 45 46 /// Attempts to apply the pattern specified as template argument to the given 47 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 48 /// function that returns the "main" result or failure. Returns failure if the 49 /// pattern failed to apply. Extra arguments are forwarded to the pattern 50 /// constructor. 51 template <typename PatternTy, typename... Args> 52 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 53 // Check if the given operation has the type expected by the pattern. 54 using OpTy = typename llvm::function_traits< 55 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 56 auto op = dyn_cast<OpTy>(operation); 57 if (!op) 58 return failure(); 59 60 // Apply the pattern directly to the op. 61 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 62 SimpleRewriter rewriter(operation->getContext()); 63 rewriter.setInsertionPoint(operation); 64 auto result = pattern.returningMatchAndRewrite(op, rewriter); 65 if (failed(result)) 66 return failure(); 67 return cast<LinalgOp>(result->getOperation()); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // DecomposeOp 72 //===----------------------------------------------------------------------===// 73 74 DiagnosedSilenceableFailure 75 transform::DecomposeOp::applyToOne(linalg::LinalgOp target, 76 SmallVectorImpl<Operation *> &results, 77 transform::TransformState &state) { 78 FailureOr<LinalgOp> windowed = 79 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 80 if (succeeded(windowed)) { 81 results.push_back(*windowed); 82 return DiagnosedSilenceableFailure(success()); 83 } 84 FailureOr<LinalgOp> depthwise = 85 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 86 if (succeeded(depthwise)) { 87 results.push_back(*depthwise); 88 return DiagnosedSilenceableFailure(success()); 89 } 90 results.assign(1, nullptr); 91 return emitDefaultSilenceableFailure(target); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // FuseOp 96 //===----------------------------------------------------------------------===// 97 98 /// Apply a tiling transformation to all payload ops and store both the 99 /// tiled operation as well as the created tile loops. 100 static LogicalResult 101 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps, 102 unsigned numLoops, 103 transform::TransformResults &transformResults, 104 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 105 SmallVector<Operation *> tiledLinalgOps; 106 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 107 for (unsigned int i = 0; i < numLoops; ++i) 108 loopOps[i].reserve(payloadOps.size()); 109 110 for (Operation *target : payloadOps) { 111 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 112 if (!linalgOp) 113 return transformOp->emitError("only LinalgOps are supported"); 114 115 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 116 if (failed(tiled)) 117 return failure(); 118 119 tiledLinalgOps.push_back(tiled->op); 120 if (tiled->loops.size() != numLoops) 121 // Not enough loops were generated. This usually means that the input size 122 // was smaller than the tiling size. 123 // TODO: LinalgTilingPattern should return failure(). 124 return failure(); 125 for (unsigned int i = 0; i < numLoops; ++i) 126 loopOps[i].push_back(tiled->loops[i]); 127 } 128 129 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 130 for (unsigned int i = 0; i < numLoops; ++i) 131 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 132 return success(); 133 } 134 135 /// Parse a tiling-like operation that returns the tiled op as well as the 136 /// created tile loops. The function counts the non-zero tile sizes to compute 137 /// the number of results. 138 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 139 StringRef sizesAttrName) { 140 OpAsmParser::UnresolvedOperand targetOperand; 141 SMLoc opLoc = parser.getCurrentLocation(); 142 if (parser.parseOperand(targetOperand) || 143 parser.parseOptionalAttrDict(result.attributes)) 144 return failure(); 145 Attribute sizesAttr = result.attributes.get(sizesAttrName); 146 if (!sizesAttr) 147 return parser.emitError(opLoc) 148 << "expected '" << sizesAttrName << "' attribute"; 149 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 150 if (!sizesArrayAttr) 151 return parser.emitError(opLoc) 152 << "'" << sizesAttrName << "' attribute must be an array"; 153 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 154 size_t numExpectedLoops = 155 sizesArrayAttr.size() - 156 llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); 157 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 158 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 159 return failure(); 160 return success(); 161 } 162 163 DiagnosedSilenceableFailure 164 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 165 mlir::transform::TransformState &state) { 166 LinalgTilingAndFusionOptions fusionOptions; 167 fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); 168 fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); 169 170 LogicalResult result = applyTilingToAll( 171 getOperation(), state.getPayloadOps(getTarget()), 172 fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), 173 transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 174 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 175 SimpleRewriter rewriter(getContext()); 176 rewriter.setInsertionPoint(linalgOp); 177 FailureOr<TileLoopNest> tileLoopNest = 178 pattern.returningMatchAndRewrite(linalgOp, rewriter); 179 if (failed(tileLoopNest)) 180 return failure(); 181 182 TiledLinalgOp tiledLinalgOp; 183 tiledLinalgOp.op = tileLoopNest->getRootOp(); 184 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 185 tileLoopNest->getLoopOps().end()}; 186 return tiledLinalgOp; 187 }); 188 return DiagnosedSilenceableFailure(result); 189 } 190 191 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 192 OperationState &result) { 193 return parseTileLikeOp( 194 parser, result, 195 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 196 } 197 198 void transform::FuseOp::print(OpAsmPrinter &p) { 199 p << ' '; 200 p << getTarget(); 201 p.printOptionalAttrDict((*this)->getAttrs()); 202 } 203 204 LogicalResult transform::FuseOp::verify() { 205 SmallVector<int64_t> permutation = 206 extractFromI64ArrayAttr(getTileInterchange()); 207 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 208 if (!std::is_permutation(sequence.begin(), sequence.end(), 209 permutation.begin(), permutation.end())) { 210 return emitOpError() << "expects interchange to be a permutation, found " 211 << getTileInterchange(); 212 } 213 return success(); 214 } 215 216 //===----------------------------------------------------------------------===// 217 // FuseIntoContainingOp 218 //===----------------------------------------------------------------------===// 219 220 static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp, 221 Operation *containingOp, 222 RewriterBase &rewriter) { 223 auto tileableProducer = dyn_cast<TilingInterface>(producerOp); 224 if (!tileableProducer) 225 return failure(); 226 227 // Search the producer slices accessed within the containing operation. 228 // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe 229 // evolve into an interface. 230 SmallVector<tensor::ExtractSliceOp> sliceOps; 231 for (Operation *user : tileableProducer->getUsers()) { 232 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); 233 if (!sliceOp) 234 continue; 235 if (!containingOp->isProperAncestor(sliceOp)) 236 continue; 237 sliceOps.push_back(sliceOp); 238 } 239 240 // Check for a non-empty list of fusion opportunities. 241 if (sliceOps.empty()) 242 return failure(); 243 244 SmallVector<Value> destinationOperands = 245 tileableProducer.getDestinationOperands(rewriter); 246 247 // Try to fuse the producer in-place. 248 SmallVector<Operation *> fusedOps; 249 for (tensor::ExtractSliceOp sliceOp : sliceOps) { 250 OpBuilder::InsertionGuard guard(rewriter); 251 rewriter.setInsertionPoint(sliceOp); 252 253 // Tile the producer. 254 FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue( 255 rewriter, /*resultNumber=*/0, destinationOperands, 256 sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true); 257 if (failed(tiledProducer)) 258 return failure(); 259 fusedOps.push_back(tiledProducer->getDefiningOp()); 260 } 261 262 // Replace the extract op. 263 for (const auto &en : enumerate(sliceOps)) 264 rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0)); 265 return fusedOps; 266 } 267 268 static FailureOr<SmallVector<Operation *>> 269 cloneAndFuse(Operation *producerOp, Operation *containingOp, 270 RewriterBase &rewriter) { 271 // Gather all uses inside the containing op. 272 SmallVector<OpOperand *> uses; 273 for (OpResult result : producerOp->getOpResults()) 274 for (OpOperand &use : result.getUses()) 275 if (containingOp->isProperAncestor(use.getOwner())) 276 uses.push_back(&use); 277 278 // Check for a non-empty list of fusion opportunities. 279 if (uses.empty()) 280 return failure(); 281 282 // Clone and fuse inside the containing op. 283 SmallVector<Operation *> fusedOps; 284 for (OpOperand *use : uses) { 285 unsigned resultNumber = use->get().cast<OpResult>().getResultNumber(); 286 OpBuilder::InsertionGuard guard(rewriter); 287 rewriter.setInsertionPoint(use->getOwner()); 288 Operation *cloned = rewriter.clone(*producerOp); 289 rewriter.updateRootInPlace( 290 use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); }); 291 fusedOps.push_back(cloned); 292 } 293 294 return fusedOps; 295 } 296 297 DiagnosedSilenceableFailure 298 transform::FuseIntoContainingOp::apply(transform::TransformResults &results, 299 transform::TransformState &state) { 300 SmallVector<Operation *> fusedOps; 301 ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp()); 302 for (Operation *producerOp : producerOps) { 303 if (producerOp->getNumResults() != 1) { 304 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); 305 diag << "op with != 1 results not supported"; 306 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 307 } 308 } 309 ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp()); 310 if (containingOps.size() != 1) 311 return DiagnosedSilenceableFailure( 312 this->emitOpError("requires exactly one containing_op handle")); 313 Operation *containingOp = containingOps.front(); 314 315 // Helper function to find the next producer that should be fused. Take any 316 // producer that has a use inside the containing op. 317 SmallVector<Operation *> remainingProducers(producerOps.begin(), 318 producerOps.end()); 319 auto getNextProducer = [&]() -> FailureOr<Operation *> { 320 for (const auto &it : enumerate(remainingProducers)) { 321 Operation *producerOp = it.value(); 322 bool hasUseInContainingOp = 323 any_of(producerOp->getUsers(), [&](Operation *op) { 324 return containingOp->isProperAncestor(op); 325 }); 326 // TODO: When resolving the TODO below (no duplicate ops), take an op that 327 // has no use among the remaining producers. This is a topological 328 // sorting. 329 if (hasUseInContainingOp) { 330 remainingProducers.erase(remainingProducers.begin() + it.index()); 331 return producerOp; 332 } 333 } 334 return failure(); 335 }; 336 337 IRRewriter rewriter(getContext()); 338 while (!remainingProducers.empty()) { 339 auto nextProducer = getNextProducer(); 340 if (failed(nextProducer)) { 341 Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note); 342 diag << "could not fuse ops into container"; 343 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 344 } 345 346 Operation *producerOp = *nextProducer; 347 // TODO: If there are multiple uses of the producer in the containing op, we 348 // currently tile/clone the op multiple times (once per use). In some cases, 349 // we can tile/clone once and reuse the value for each use. Futhermore, 350 // producers should then be traversed according to a topological sorting. 351 auto tiled = tileAndFuse(producerOp, containingOp, rewriter); 352 if (succeeded(tiled)) 353 fusedOps.append(*tiled); 354 355 auto cloned = cloneAndFuse(producerOp, containingOp, rewriter); 356 if (succeeded(cloned)) 357 fusedOps.append(*cloned); 358 359 if (failed(tiled) && failed(cloned)) { 360 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); 361 diag << "could not fuse into containing op"; 362 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 363 } 364 } 365 366 results.set(getFusedOp().cast<OpResult>(), fusedOps); 367 return DiagnosedSilenceableFailure::success(); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // GeneralizeOp 372 //===----------------------------------------------------------------------===// 373 374 DiagnosedSilenceableFailure 375 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, 376 SmallVectorImpl<Operation *> &results, 377 transform::TransformState &state) { 378 // Exit early if no transformation is needed. 379 if (isa<GenericOp>(target)) { 380 results.push_back(target); 381 return DiagnosedSilenceableFailure(success()); 382 } 383 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 384 if (succeeded(generic)) { 385 results.push_back(generic->getOperation()); 386 return DiagnosedSilenceableFailure(success()); 387 } 388 results.assign(1, nullptr); 389 return emitDefaultSilenceableFailure(target); 390 } 391 392 //===----------------------------------------------------------------------===// 393 // InterchangeOp 394 //===----------------------------------------------------------------------===// 395 396 DiagnosedSilenceableFailure 397 transform::InterchangeOp::applyToOne(linalg::GenericOp target, 398 SmallVectorImpl<Operation *> &results, 399 transform::TransformState &state) { 400 SmallVector<unsigned> interchangeVector = 401 extractUIntArray(getIteratorInterchange()); 402 // Exit early if no transformation is needed. 403 if (interchangeVector.empty()) { 404 results.push_back(target); 405 return DiagnosedSilenceableFailure(success()); 406 } 407 SimpleRewriter rewriter(target->getContext()); 408 FailureOr<GenericOp> res = 409 interchangeGenericOp(rewriter, target, interchangeVector); 410 if (failed(res)) 411 return DiagnosedSilenceableFailure::definiteFailure(); 412 results.push_back(res->getOperation()); 413 return DiagnosedSilenceableFailure(success()); 414 } 415 416 LogicalResult transform::InterchangeOp::verify() { 417 SmallVector<unsigned> permutation = 418 extractUIntArray(getIteratorInterchange()); 419 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 420 if (!std::is_permutation(sequence.begin(), sequence.end(), 421 permutation.begin(), permutation.end())) { 422 return emitOpError() 423 << "expects iterator_interchange to be a permutation, found " 424 << getIteratorInterchange(); 425 } 426 return success(); 427 } 428 429 //===---------------------------------------------------------------------===// 430 // MatchOp 431 //===---------------------------------------------------------------------===// 432 433 DiagnosedSilenceableFailure 434 transform::MatchOp::apply(transform::TransformResults &results, 435 transform::TransformState &state) { 436 llvm::StringSet<> strs; 437 if (getOps().hasValue()) 438 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(), 439 getOps()->getAsValueRange<StringAttr>().end()); 440 441 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 442 if (payloadOps.size() != 1) 443 return DiagnosedSilenceableFailure( 444 this->emitOpError("requires exactly one target handle")); 445 446 SmallVector<Operation *> res; 447 auto matchFun = [&](Operation *op) { 448 if (getOps().hasValue() && !strs.contains(op->getName().getStringRef())) 449 return WalkResult::advance(); 450 451 // Interfaces cannot be matched by name, just by ID. 452 // So we specifically encode the interfaces we care about for this op. 453 if (getInterface().hasValue()) { 454 auto iface = getInterface().getValue(); 455 if (iface == transform::MatchInterfaceEnum::LinalgOp && 456 !isa<linalg::LinalgOp>(op)) 457 return WalkResult::advance(); 458 if (iface == transform::MatchInterfaceEnum::TilingInterface && 459 isa<TilingInterface>(op)) 460 return WalkResult::advance(); 461 } 462 463 if (getAttribute().hasValue() && !op->hasAttr(getAttribute().getValue())) 464 return WalkResult::advance(); 465 466 // All constraints are satisfied. 467 res.push_back(op); 468 return WalkResult::advance(); 469 }; 470 471 payloadOps.front()->walk(matchFun); 472 results.set(getResult().cast<OpResult>(), res); 473 return DiagnosedSilenceableFailure(success()); 474 } 475 476 //===---------------------------------------------------------------------===// 477 // MultiTileSizesOp 478 //===---------------------------------------------------------------------===// 479 480 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( 481 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) { 482 OpBuilder builder(target.getContext()); 483 builder.setInsertionPoint(target); 484 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); 485 OpFoldResult divisor = builder.getIndexAttr(getDivisor()); 486 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( 487 builder, target, getDimension(), targetSize, divisor); 488 if (failed(spec)) { 489 return emitSilenceableError() << "could not generate tile size computation"; 490 } 491 492 AffineExpr s0 = builder.getAffineSymbolExpr(0); 493 AffineExpr s1 = builder.getAffineSymbolExpr(1); 494 Operation *splitPoint = 495 makeComposedAffineApply(builder, target.getLoc(), s0 * s1, 496 {spec->lowTileSize, spec->lowTripCount}); 497 Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); 498 Operation *highTileSize = spec->highTileSize.getDefiningOp(); 499 assert(lowTileSize && highTileSize && splitPoint && 500 "tile sizes are not produced by operations"); 501 results.reserve(results.size() + 3); 502 results.push_back(lowTileSize); 503 results.push_back(highTileSize); 504 results.push_back(splitPoint); 505 return DiagnosedSilenceableFailure::success(); 506 } 507 508 void transform::MultiTileSizesOp::getEffects( 509 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 510 onlyReadsHandle(getTarget(), effects); 511 producesHandle(getResults(), effects); 512 modifiesPayload(effects); 513 } 514 515 //===---------------------------------------------------------------------===// 516 // PadOp 517 //===---------------------------------------------------------------------===// 518 519 DiagnosedSilenceableFailure 520 transform::PadOp::applyToOne(linalg::LinalgOp target, 521 SmallVectorImpl<Operation *> &results, 522 transform::TransformState &state) { 523 // Convert the integer packing flags to booleans. 524 SmallVector<bool> packPaddings; 525 for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) 526 packPaddings.push_back(static_cast<bool>(packPadding)); 527 528 // Convert the padding values to attributes. 529 SmallVector<Attribute> paddingValues; 530 for (auto const &it : 531 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 532 Attribute attr = std::get<0>(it); 533 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 534 // Try to parse string attributes to obtain an attribute of element type. 535 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 536 paddingValues.push_back( 537 parseAttribute(attr.cast<StringAttr>(), elementType)); 538 if (!paddingValues.back()) { 539 auto diag = this->emitOpError("expects a padding that parses to ") 540 << elementType << ", got " << std::get<0>(it); 541 diag.attachNote(target.getLoc()) << "when applied to this op"; 542 return DiagnosedSilenceableFailure::definiteFailure(); 543 } 544 continue; 545 } 546 // Otherwise, add the attribute directly. 547 if (attr.getType() != elementType) { 548 auto diag = this->emitOpError("expects a padding value of type ") 549 << elementType << ", got " << attr; 550 diag.attachNote(target.getLoc()) << "when applied to this op"; 551 return DiagnosedSilenceableFailure::definiteFailure(); 552 } 553 paddingValues.push_back(attr); 554 } 555 556 // Extract the transpose vectors. 557 SmallVector<SmallVector<int64_t>> transposePaddings; 558 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 559 transposePaddings.push_back( 560 extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>())); 561 562 LinalgPaddingOptions paddingOptions; 563 paddingOptions.setPaddingValues(paddingValues); 564 paddingOptions.setPaddingDimensions( 565 extractFromI64ArrayAttr(getPaddingDimensions())); 566 paddingOptions.setPackPaddings(packPaddings); 567 paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); 568 paddingOptions.setTransposePaddings(transposePaddings); 569 570 FailureOr<LinalgOp> result = 571 tryApply<LinalgPaddingPattern>(target, paddingOptions); 572 if (succeeded(result)) { 573 results.push_back(result->getOperation()); 574 return DiagnosedSilenceableFailure(success()); 575 } 576 577 results.assign(1, nullptr); 578 return emitDefaultSilenceableFailure(target); 579 } 580 581 LogicalResult transform::PadOp::verify() { 582 SmallVector<int64_t> packPaddings = 583 extractFromI64ArrayAttr(getPackPaddings()); 584 if (any_of(packPaddings, [](int64_t packPadding) { 585 return packPadding != 0 && packPadding != 1; 586 })) { 587 return emitOpError() 588 << "expects pack_paddings to contain booleans (0/1), found " 589 << getPackPaddings(); 590 } 591 592 SmallVector<int64_t> paddingDimensions = 593 extractFromI64ArrayAttr(getPaddingDimensions()); 594 if (any_of(paddingDimensions, 595 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 596 return emitOpError() 597 << "expects padding_dimensions to contain positive integers, found " 598 << getPaddingDimensions(); 599 } 600 601 SmallVector<int64_t> hoistPaddings = 602 extractFromI64ArrayAttr(getHoistPaddings()); 603 if (any_of(hoistPaddings, 604 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 605 return emitOpError() 606 << "expects hoist_paddings to contain positive integers, found " 607 << getHoistPaddings(); 608 } 609 610 ArrayAttr transposes = getTransposePaddings(); 611 for (Attribute attr : transposes) { 612 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 613 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 614 if (!std::is_permutation(sequence.begin(), sequence.end(), 615 transpose.begin(), transpose.end())) { 616 return emitOpError() 617 << "expects transpose_paddings to be a permutation, found " 618 << attr; 619 } 620 } 621 return success(); 622 } 623 624 //===----------------------------------------------------------------------===// 625 // PromoteOp 626 //===----------------------------------------------------------------------===// 627 628 DiagnosedSilenceableFailure 629 transform::PromoteOp::applyToOne(linalg::LinalgOp target, 630 SmallVectorImpl<Operation *> &results, 631 transform::TransformState &state) { 632 LinalgPromotionOptions promotionOptions; 633 if (!getOperandsToPromote().empty()) 634 promotionOptions = promotionOptions.setOperandsToPromote( 635 extractFromI64ArrayAttr(getOperandsToPromote())); 636 if (getUseFullTilesByDefault()) 637 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( 638 getUseFullTilesByDefault()); 639 if (getUseAlloca()) 640 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); 641 if (!getUseFullTileBuffers().empty()) 642 promotionOptions = promotionOptions.setUseFullTileBuffers( 643 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); 644 if (getAlignment().has_value()) 645 promotionOptions = promotionOptions.setAlignment(*getAlignment()); 646 647 if (failed(promoteSubviewsPrecondition(target, promotionOptions))) 648 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 649 650 SimpleRewriter rewriter(target->getContext()); 651 rewriter.setInsertionPoint(target); 652 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); 653 if (failed(res)) 654 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 655 results.push_back(target); 656 return DiagnosedSilenceableFailure(success()); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // ScalarizeOp 661 //===----------------------------------------------------------------------===// 662 663 DiagnosedSilenceableFailure 664 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, 665 SmallVectorImpl<Operation *> &results, 666 transform::TransformState &state) { 667 LinalgTilingOptions tilingOptions; 668 tilingOptions.scalarizeDynamicDims(); 669 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 670 // sizes and asserts that it is not already set. 671 SmallVector<int64_t> emptyTileSizes; 672 LinalgTilingPattern pattern(getContext(), tilingOptions); 673 SimpleRewriter rewriter(getContext()); 674 rewriter.setInsertionPoint(target); 675 FailureOr<TiledLinalgOp> result = 676 pattern.returningMatchAndRewrite(target, rewriter); 677 if (failed(result)) 678 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 679 680 results.push_back(result->op); 681 return DiagnosedSilenceableFailure(success()); 682 } 683 684 //===----------------------------------------------------------------------===// 685 // SplitOp 686 //===----------------------------------------------------------------------===// 687 688 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, 689 TransformState &state) { 690 // Collect the dynamic split points if provided. 691 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); 692 SimpleRewriter rewriter(getContext()); 693 SmallVector<OpFoldResult> splitPoints; 694 splitPoints.reserve(payload.size()); 695 if (getDynamicSplitPoint()) { 696 auto diag = DiagnosedSilenceableFailure::success(); 697 splitPoints = llvm::to_vector(llvm::map_range( 698 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { 699 if (op->getNumResults() != 1 || 700 !op->getResult(0).getType().isIndex()) { 701 diag = emitSilenceableError() 702 << "expected dynamic split point handle to point to a " 703 "single-result index-typed op"; 704 diag.attachNote(op->getLoc()) << "dynamic split point"; 705 } 706 return OpFoldResult(op->getResult(0)); 707 })); 708 if (!diag.succeeded()) 709 return diag; 710 711 if (splitPoints.size() != payload.size()) { 712 emitError() << "expected the dynamic split point handle to point to as " 713 "many operations (" 714 << splitPoints.size() << ") as the target handle (" 715 << payload.size() << ")"; 716 return DiagnosedSilenceableFailure::definiteFailure(); 717 } 718 } else { 719 splitPoints.resize(payload.size(), 720 rewriter.getIndexAttr(getStaticSplitPoint())); 721 } 722 723 // Split each target operation. 724 SmallVector<Operation *> first, second; 725 for (const auto &pair : llvm::zip(payload, splitPoints)) { 726 Operation *target = std::get<0>(pair); 727 auto linalgOp = dyn_cast<LinalgOp>(target); 728 if (!linalgOp) { 729 auto diag = emitSilenceableError() << "only applies to structured ops"; 730 diag.attachNote(target->getLoc()) << "target op"; 731 return diag; 732 } 733 734 if (getDimension() >= linalgOp.getNumLoops()) { 735 auto diag = emitSilenceableError() << "dimension " << getDimension() 736 << " does not exist in target op"; 737 diag.attachNote(target->getLoc()) << "target op"; 738 return diag; 739 } 740 741 rewriter.setInsertionPoint(linalgOp); 742 std::tie(first.emplace_back(), second.emplace_back()) = 743 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); 744 } 745 746 results.set(getFirst().cast<OpResult>(), first); 747 results.set(getSecond().cast<OpResult>(), second); 748 return DiagnosedSilenceableFailure::success(); 749 } 750 751 void SplitOp::getEffects( 752 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 753 consumesHandle(getTarget(), effects); 754 if (getDynamicSplitPoint()) 755 onlyReadsHandle(getDynamicSplitPoint(), effects); 756 producesHandle(getResults(), effects); 757 modifiesPayload(effects); 758 } 759 760 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 761 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; 762 IntegerAttr staticSplitPoint; 763 auto pdlOperationType = 764 pdl::OperationType::get(parser.getBuilder().getContext()); 765 if (parser.parseOperand(target) || 766 parser.resolveOperand(target, pdlOperationType, result.operands) || 767 parser.parseKeyword("after")) 768 return failure(); 769 770 OptionalParseResult dynamicPointParseResult = 771 parser.parseOptionalOperand(dynamicSplitPoint); 772 if (!dynamicPointParseResult.hasValue()) { 773 int64_t staticSplitPointValue; 774 if (failed(parser.parseInteger(staticSplitPointValue))) 775 return failure(); 776 777 staticSplitPoint = 778 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); 779 } else { 780 if (failed(*dynamicPointParseResult) || 781 parser.resolveOperand(dynamicSplitPoint, pdlOperationType, 782 result.operands)) { 783 return failure(); 784 } 785 786 staticSplitPoint = 787 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); 788 } 789 790 result.addAttribute( 791 SplitOp::getStaticSplitPointAttrName(result.name).getValue(), 792 staticSplitPoint); 793 if (failed(parser.parseOptionalAttrDict(result.attributes))) 794 return failure(); 795 796 result.addTypes({pdlOperationType, pdlOperationType}); 797 return success(); 798 } 799 800 void SplitOp::print(OpAsmPrinter &printer) { 801 printer << " " << getTarget() << " after "; 802 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); 803 if (staticSplitSize != ShapedType::kDynamicSize) 804 printer << staticSplitSize; 805 else 806 printer << getDynamicSplitPoint(); 807 printer << " "; 808 printer.printOptionalAttrDict(getOperation()->getAttrs(), 809 {getStaticSplitPointAttrName()}); 810 } 811 812 LogicalResult SplitOp::verify() { 813 if ((static_cast<int64_t>(getStaticSplitPoint()) != 814 ShapedType::kDynamicSize) ^ 815 (getDynamicSplitPoint() == nullptr)) { 816 return emitOpError() 817 << "expects either a dynamic or a static split point to be provided"; 818 } 819 return success(); 820 } 821 822 //===----------------------------------------------------------------------===// 823 // SplitReductionOp 824 //===----------------------------------------------------------------------===// 825 826 DiagnosedSilenceableFailure 827 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, 828 SmallVectorImpl<Operation *> &results, 829 transform::TransformState &state) { 830 ControlSplitReductionFn splitFn = [&](LinalgOp) { 831 return std::pair<int64_t, unsigned>(getSplitFactor(), 832 getInsertSplitDimension()); 833 }; 834 SimpleRewriter rewriter(getContext()); 835 rewriter.setInsertionPoint(target); 836 FailureOr<SplitReductionResult> splitResult = 837 (getUseScalingAlgorithm()) 838 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 839 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 840 if (failed(splitResult)) 841 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 842 843 results.push_back(splitResult->initOrAlloc); 844 results.push_back(splitResult->fillOp); 845 results.push_back(splitResult->splitLinalgOp); 846 results.push_back(splitResult->resultCombiningLinalgOp); 847 return DiagnosedSilenceableFailure(success()); 848 } 849 850 //===----------------------------------------------------------------------===// 851 // TileOp 852 //===----------------------------------------------------------------------===// 853 854 DiagnosedSilenceableFailure 855 transform::TileOp::apply(TransformResults &transformResults, 856 TransformState &state) { 857 LinalgTilingOptions tilingOptions; 858 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 859 860 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); 861 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; 862 dynamicSizeProducers.reserve(getDynamicSizes().size()); 863 for (Value dynamicSizeProducerHandle : getDynamicSizes()) { 864 dynamicSizeProducers.push_back( 865 state.getPayloadOps(dynamicSizeProducerHandle)); 866 867 if (dynamicSizeProducers.back().size() != targets.size()) { 868 DiagnosedSilenceableFailure diag = 869 emitSilenceableError() 870 << "expected as many dynamic size-producing operations (" 871 << dynamicSizeProducers.back().size() << ") as target ops (" 872 << targets.size() << ")"; 873 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 874 return diag; 875 } 876 877 for (Operation *op : dynamicSizeProducers.back()) { 878 if (op->getNumResults() == 1 && 879 op->getResult(0).getType().isa<IndexType>()) 880 continue; 881 DiagnosedSilenceableFailure diag = 882 emitSilenceableError() << "expected sizes to be produced by ops " 883 "with a single index-type result"; 884 diag.attachNote(op->getLoc()) << "size producer op"; 885 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 886 return diag; 887 } 888 } 889 890 SmallVector<Operation *> tiled; 891 SmallVector<SmallVector<Operation *, 4>, 4> loops; 892 loops.resize(getLoops().size()); 893 for (auto &en : llvm::enumerate(targets)) { 894 auto linalgOp = dyn_cast<LinalgOp>(en.value()); 895 if (!linalgOp) { 896 DiagnosedSilenceableFailure diag = emitSilenceableError() 897 << "only linalg ops are supported"; 898 diag.attachNote(en.value()->getLoc()) << "target op"; 899 return diag; 900 } 901 902 unsigned index = en.index(); 903 if (!tileSizes.empty()) { 904 tilingOptions.setTileSizeComputationFunction( 905 [&, index](OpBuilder &b, Operation *) { 906 SmallVector<Value, 4> sizes; 907 sizes.reserve(tileSizes.size()); 908 unsigned dynamicIdx = 0; 909 for (OpFoldResult ofr : getMixedSizes()) { 910 if (auto attr = ofr.dyn_cast<Attribute>()) { 911 sizes.push_back(b.create<arith::ConstantIndexOp>( 912 getLoc(), attr.cast<IntegerAttr>().getInt())); 913 } else { 914 sizes.push_back( 915 dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); 916 } 917 } 918 return sizes; 919 }); 920 } 921 922 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 923 LinalgTilingPattern pattern(getContext(), tilingOptions); 924 SimpleRewriter rewriter(linalgOp.getContext()); 925 FailureOr<TiledLinalgOp> tiledOp = 926 pattern.returningMatchAndRewrite(linalgOp, rewriter); 927 if (failed(tiledOp)) 928 return DiagnosedSilenceableFailure::definiteFailure(); 929 930 tiled.push_back(tiledOp->op); 931 for (const auto &en2 : llvm::enumerate(tiledOp->loops)) 932 loops[en2.index()].push_back(en2.value()); 933 } 934 935 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); 936 for (const auto &en : llvm::enumerate(loops)) 937 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); 938 939 return DiagnosedSilenceableFailure::success(); 940 } 941 942 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() { 943 ValueRange dynamic = getDynamicSizes(); 944 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 945 SmallVector<OpFoldResult> results; 946 results.reserve(tileSizes.size()); 947 unsigned dynamicPos = 0; 948 Builder builder(getContext()); 949 for (int64_t size : tileSizes) { 950 if (size == ShapedType::kDynamicSize) { 951 results.push_back(dynamic[dynamicPos++]); 952 } else { 953 results.push_back(builder.getIndexAttr(size)); 954 } 955 } 956 return results; 957 } 958 959 ParseResult transform::TileOp::parse(OpAsmParser &parser, 960 OperationState &result) { 961 OpAsmParser::UnresolvedOperand target; 962 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; 963 ArrayAttr staticSizes; 964 auto pdlOperationType = pdl::OperationType::get(parser.getContext()); 965 if (parser.parseOperand(target) || 966 parser.resolveOperand(target, pdlOperationType, result.operands) || 967 parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || 968 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || 969 parser.parseOptionalAttrDict(result.attributes)) 970 return ParseResult::failure(); 971 972 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); 973 size_t numExpectedLoops = 974 staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); 975 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); 976 return success(); 977 } 978 979 void TileOp::print(OpAsmPrinter &p) { 980 p << ' ' << getTarget(); 981 printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), 982 getStaticSizes()); 983 p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); 984 } 985 986 void transform::TileOp::getEffects( 987 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 988 consumesHandle(getTarget(), effects); 989 onlyReadsHandle(getDynamicSizes(), effects); 990 producesHandle(getTiledLinalgOp(), effects); 991 producesHandle(getLoops(), effects); 992 modifiesPayload(effects); 993 } 994 995 //===----------------------------------------------------------------------===// 996 // TileToForeachThreadOp 997 //===----------------------------------------------------------------------===// 998 999 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( 1000 TilingInterface target, SmallVectorImpl<Operation *> &results, 1001 transform::TransformState &state) { 1002 IRRewriter rewriter(getContext()); 1003 rewriter.setInsertionPoint(target); 1004 auto maybeThreadDimMappingAttr = getThreadDimMapping(); 1005 auto dimMapping = 1006 llvm::to_vector(maybeThreadDimMappingAttr 1007 ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) 1008 : ArrayRef<int64_t>{}); 1009 1010 FailureOr<ForeachThreadTilingResult> tilingResult = failure(); 1011 if (Optional<ArrayAttr> numThreads = getNumThreads()) 1012 tilingResult = linalg::tileToForeachThreadOp( 1013 rewriter, target, getAsOpFoldResult(*numThreads), dimMapping); 1014 1015 if (Optional<ArrayAttr> tileSizes = getTileSizes()) 1016 tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( 1017 rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping); 1018 1019 if (failed(tilingResult)) 1020 return emitDefaultSilenceableFailure(target); 1021 rewriter.replaceOp(target, tilingResult->tileOp->getResults()); 1022 results.assign({tilingResult->tileOp, tilingResult->tiledOp}); 1023 return DiagnosedSilenceableFailure(success()); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // VectorizeOp 1028 //===----------------------------------------------------------------------===// 1029 1030 DiagnosedSilenceableFailure 1031 transform::VectorizeOp::applyToOne(Operation *target, 1032 SmallVectorImpl<Operation *> &results, 1033 transform::TransformState &state) { 1034 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1035 auto diag = this->emitOpError("requires isolated-from-above targets"); 1036 diag.attachNote(target->getLoc()) << "non-isolated target"; 1037 return DiagnosedSilenceableFailure::definiteFailure(); 1038 } 1039 1040 MLIRContext *ctx = getContext(); 1041 RewritePatternSet patterns(ctx); 1042 patterns.add<LinalgVectorizationPattern>(ctx); 1043 1044 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 1045 vector::populateVectorReductionToContractPatterns(patterns); 1046 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 1047 linalg::LinalgCopyVTWForwardingPattern>(ctx, 1048 /*benefit=*/2); 1049 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 1050 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 1051 if (getVectorizePadding()) 1052 linalg::populatePadOpVectorizationPatterns(patterns); 1053 1054 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 1055 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 1056 1057 results.push_back(target); 1058 return DiagnosedSilenceableFailure(success()); 1059 } 1060 1061 //===----------------------------------------------------------------------===// 1062 // Transform op registration 1063 //===----------------------------------------------------------------------===// 1064 1065 namespace { 1066 /// Registers new ops and declares PDL as dependent dialect since the additional 1067 /// ops are using PDL types for operands and results. 1068 class LinalgTransformDialectExtension 1069 : public transform::TransformDialectExtension< 1070 LinalgTransformDialectExtension> { 1071 public: 1072 LinalgTransformDialectExtension() { 1073 declareDependentDialect<AffineDialect>(); 1074 declareDependentDialect<arith::ArithmeticDialect>(); 1075 declareDependentDialect<pdl::PDLDialect>(); 1076 declareDependentDialect<scf::SCFDialect>(); 1077 declareDependentDialect<vector::VectorDialect>(); 1078 registerTransformOps< 1079 #define GET_OP_LIST 1080 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 1081 >(); 1082 } 1083 }; 1084 } // namespace 1085 1086 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" 1087 1088 #define GET_OP_CLASSES 1089 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 1090 1091 void mlir::linalg::registerTransformDialectExtension( 1092 DialectRegistry ®istry) { 1093 registry.addExtensions<LinalgTransformDialectExtension>(); 1094 } 1095