1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/AsmParser/AsmParser.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 18 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 19 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 20 #include "mlir/Interfaces/TilingInterface.h" 21 #include "mlir/Parser/Parser.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "llvm/ADT/StringSet.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 using namespace mlir::transform; 28 29 /// Extracts a vector of unsigned from an array attribute. Asserts if the 30 /// attribute contains values other than intergers. May truncate. 31 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 32 SmallVector<unsigned> result; 33 result.reserve(attr.size()); 34 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 35 result.push_back(value.getZExtValue()); 36 return result; 37 } 38 39 namespace { 40 /// A simple pattern rewriter that implements no special logic. 41 class SimpleRewriter : public PatternRewriter { 42 public: 43 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 44 }; 45 } // namespace 46 47 /// Attempts to apply the pattern specified as template argument to the given 48 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 49 /// function that returns the "main" result or failure. Returns failure if the 50 /// pattern failed to apply. Extra arguments are forwarded to the pattern 51 /// constructor. 52 template <typename PatternTy, typename... Args> 53 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 54 // Check if the given operation has the type expected by the pattern. 55 using OpTy = typename llvm::function_traits< 56 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 57 auto op = dyn_cast<OpTy>(operation); 58 if (!op) 59 return failure(); 60 61 // Apply the pattern directly to the op. 62 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 63 SimpleRewriter rewriter(operation->getContext()); 64 rewriter.setInsertionPoint(operation); 65 auto result = pattern.returningMatchAndRewrite(op, rewriter); 66 if (failed(result)) 67 return failure(); 68 return cast<LinalgOp>(result->getOperation()); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // DecomposeOp 73 //===----------------------------------------------------------------------===// 74 75 DiagnosedSilenceableFailure 76 transform::DecomposeOp::applyToOne(linalg::LinalgOp target, 77 SmallVectorImpl<Operation *> &results, 78 transform::TransformState &state) { 79 FailureOr<LinalgOp> windowed = 80 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 81 if (succeeded(windowed)) { 82 results.push_back(*windowed); 83 return DiagnosedSilenceableFailure(success()); 84 } 85 FailureOr<LinalgOp> depthwise = 86 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 87 if (succeeded(depthwise)) { 88 results.push_back(*depthwise); 89 return DiagnosedSilenceableFailure(success()); 90 } 91 results.assign(1, nullptr); 92 return emitDefaultSilenceableFailure(target); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // FuseOp 97 //===----------------------------------------------------------------------===// 98 99 /// Apply a tiling transformation to all payload ops and store both the 100 /// tiled operation as well as the created tile loops. 101 static LogicalResult 102 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps, 103 unsigned numLoops, 104 transform::TransformResults &transformResults, 105 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 106 SmallVector<Operation *> tiledLinalgOps; 107 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 108 for (unsigned int i = 0; i < numLoops; ++i) 109 loopOps[i].reserve(payloadOps.size()); 110 111 for (Operation *target : payloadOps) { 112 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 113 if (!linalgOp) 114 return transformOp->emitError("only LinalgOps are supported"); 115 116 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 117 if (failed(tiled)) 118 return failure(); 119 120 tiledLinalgOps.push_back(tiled->op); 121 if (tiled->loops.size() != numLoops) 122 // Not enough loops were generated. This usually means that the input size 123 // was smaller than the tiling size. 124 // TODO: LinalgTilingPattern should return failure(). 125 return failure(); 126 for (unsigned int i = 0; i < numLoops; ++i) 127 loopOps[i].push_back(tiled->loops[i]); 128 } 129 130 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 131 for (unsigned int i = 0; i < numLoops; ++i) 132 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 133 return success(); 134 } 135 136 /// Parse a tiling-like operation that returns the tiled op as well as the 137 /// created tile loops. The function counts the non-zero tile sizes to compute 138 /// the number of results. 139 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 140 StringRef sizesAttrName) { 141 OpAsmParser::UnresolvedOperand targetOperand; 142 SMLoc opLoc = parser.getCurrentLocation(); 143 if (parser.parseOperand(targetOperand) || 144 parser.parseOptionalAttrDict(result.attributes)) 145 return failure(); 146 Attribute sizesAttr = result.attributes.get(sizesAttrName); 147 if (!sizesAttr) 148 return parser.emitError(opLoc) 149 << "expected '" << sizesAttrName << "' attribute"; 150 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 151 if (!sizesArrayAttr) 152 return parser.emitError(opLoc) 153 << "'" << sizesAttrName << "' attribute must be an array"; 154 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 155 size_t numExpectedLoops = 156 sizesArrayAttr.size() - 157 llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); 158 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 159 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 160 return failure(); 161 return success(); 162 } 163 164 DiagnosedSilenceableFailure 165 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 166 mlir::transform::TransformState &state) { 167 LinalgTilingAndFusionOptions fusionOptions; 168 fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); 169 fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); 170 171 LogicalResult result = applyTilingToAll( 172 getOperation(), state.getPayloadOps(getTarget()), 173 fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), 174 transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 175 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 176 SimpleRewriter rewriter(getContext()); 177 rewriter.setInsertionPoint(linalgOp); 178 FailureOr<TileLoopNest> tileLoopNest = 179 pattern.returningMatchAndRewrite(linalgOp, rewriter); 180 if (failed(tileLoopNest)) 181 return failure(); 182 183 TiledLinalgOp tiledLinalgOp; 184 tiledLinalgOp.op = tileLoopNest->getRootOp(); 185 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 186 tileLoopNest->getLoopOps().end()}; 187 return tiledLinalgOp; 188 }); 189 return DiagnosedSilenceableFailure(result); 190 } 191 192 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 193 OperationState &result) { 194 return parseTileLikeOp( 195 parser, result, 196 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 197 } 198 199 void transform::FuseOp::print(OpAsmPrinter &p) { 200 p << ' '; 201 p << getTarget(); 202 p.printOptionalAttrDict((*this)->getAttrs()); 203 } 204 205 LogicalResult transform::FuseOp::verify() { 206 SmallVector<int64_t> permutation = 207 extractFromI64ArrayAttr(getTileInterchange()); 208 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 209 if (!std::is_permutation(sequence.begin(), sequence.end(), 210 permutation.begin(), permutation.end())) { 211 return emitOpError() << "expects interchange to be a permutation, found " 212 << getTileInterchange(); 213 } 214 return success(); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // FuseIntoContainingOp 219 //===----------------------------------------------------------------------===// 220 221 static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp, 222 Operation *containingOp, 223 RewriterBase &rewriter) { 224 auto tileableProducer = dyn_cast<TilingInterface>(producerOp); 225 if (!tileableProducer) 226 return failure(); 227 228 // Search the producer slices accessed within the containing operation. 229 // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe 230 // evolve into an interface. 231 SmallVector<tensor::ExtractSliceOp> sliceOps; 232 for (Operation *user : tileableProducer->getUsers()) { 233 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); 234 if (!sliceOp) 235 continue; 236 if (!containingOp->isProperAncestor(sliceOp)) 237 continue; 238 sliceOps.push_back(sliceOp); 239 } 240 241 // Check for a non-empty list of fusion opportunities. 242 if (sliceOps.empty()) 243 return failure(); 244 245 SmallVector<Value> destinationOperands = 246 tileableProducer.getDestinationOperands(rewriter); 247 248 // Try to fuse the producer in-place. 249 SmallVector<Operation *> fusedOps; 250 for (tensor::ExtractSliceOp sliceOp : sliceOps) { 251 OpBuilder::InsertionGuard guard(rewriter); 252 rewriter.setInsertionPoint(sliceOp); 253 254 // Tile the producer. 255 FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue( 256 rewriter, /*resultNumber=*/0, destinationOperands, 257 sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true); 258 if (failed(tiledProducer)) 259 return failure(); 260 fusedOps.push_back(tiledProducer->getDefiningOp()); 261 } 262 263 // Replace the extract op. 264 for (const auto &en : enumerate(sliceOps)) 265 rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0)); 266 return fusedOps; 267 } 268 269 static FailureOr<SmallVector<Operation *>> 270 cloneAndFuse(Operation *producerOp, Operation *containingOp, 271 RewriterBase &rewriter) { 272 // Gather all uses inside the containing op. 273 SmallVector<OpOperand *> uses; 274 for (OpResult result : producerOp->getOpResults()) 275 for (OpOperand &use : result.getUses()) 276 if (containingOp->isProperAncestor(use.getOwner())) 277 uses.push_back(&use); 278 279 // Check for a non-empty list of fusion opportunities. 280 if (uses.empty()) 281 return failure(); 282 283 // Clone and fuse inside the containing op. 284 SmallVector<Operation *> fusedOps; 285 for (OpOperand *use : uses) { 286 unsigned resultNumber = use->get().cast<OpResult>().getResultNumber(); 287 OpBuilder::InsertionGuard guard(rewriter); 288 rewriter.setInsertionPoint(use->getOwner()); 289 Operation *cloned = rewriter.clone(*producerOp); 290 rewriter.updateRootInPlace( 291 use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); }); 292 fusedOps.push_back(cloned); 293 } 294 295 return fusedOps; 296 } 297 298 DiagnosedSilenceableFailure 299 transform::FuseIntoContainingOp::apply(transform::TransformResults &results, 300 transform::TransformState &state) { 301 SmallVector<Operation *> fusedOps; 302 ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp()); 303 for (Operation *producerOp : producerOps) { 304 if (producerOp->getNumResults() != 1) { 305 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); 306 diag << "op with != 1 results not supported"; 307 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 308 } 309 } 310 ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp()); 311 if (containingOps.size() != 1) 312 return DiagnosedSilenceableFailure( 313 this->emitOpError("requires exactly one containing_op handle")); 314 Operation *containingOp = containingOps.front(); 315 316 // Helper function to find the next producer that should be fused. Take any 317 // producer that has a use inside the containing op. 318 SmallVector<Operation *> remainingProducers(producerOps.begin(), 319 producerOps.end()); 320 auto getNextProducer = [&]() -> FailureOr<Operation *> { 321 for (const auto &it : enumerate(remainingProducers)) { 322 Operation *producerOp = it.value(); 323 bool hasUseInContainingOp = 324 any_of(producerOp->getUsers(), [&](Operation *op) { 325 return containingOp->isProperAncestor(op); 326 }); 327 // TODO: When resolving the TODO below (no duplicate ops), take an op that 328 // has no use among the remaining producers. This is a topological 329 // sorting. 330 if (hasUseInContainingOp) { 331 remainingProducers.erase(remainingProducers.begin() + it.index()); 332 return producerOp; 333 } 334 } 335 return failure(); 336 }; 337 338 IRRewriter rewriter(getContext()); 339 while (!remainingProducers.empty()) { 340 auto nextProducer = getNextProducer(); 341 if (failed(nextProducer)) { 342 Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note); 343 diag << "could not fuse ops into container"; 344 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 345 } 346 347 Operation *producerOp = *nextProducer; 348 // TODO: If there are multiple uses of the producer in the containing op, we 349 // currently tile/clone the op multiple times (once per use). In some cases, 350 // we can tile/clone once and reuse the value for each use. Futhermore, 351 // producers should then be traversed according to a topological sorting. 352 auto tiled = tileAndFuse(producerOp, containingOp, rewriter); 353 if (succeeded(tiled)) 354 fusedOps.append(*tiled); 355 356 auto cloned = cloneAndFuse(producerOp, containingOp, rewriter); 357 if (succeeded(cloned)) 358 fusedOps.append(*cloned); 359 360 if (failed(tiled) && failed(cloned)) { 361 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); 362 diag << "could not fuse into containing op"; 363 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 364 } 365 } 366 367 results.set(getFusedOp().cast<OpResult>(), fusedOps); 368 return DiagnosedSilenceableFailure::success(); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // GeneralizeOp 373 //===----------------------------------------------------------------------===// 374 375 DiagnosedSilenceableFailure 376 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, 377 SmallVectorImpl<Operation *> &results, 378 transform::TransformState &state) { 379 // Exit early if no transformation is needed. 380 if (isa<GenericOp>(target)) { 381 results.push_back(target); 382 return DiagnosedSilenceableFailure(success()); 383 } 384 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 385 if (succeeded(generic)) { 386 results.push_back(generic->getOperation()); 387 return DiagnosedSilenceableFailure(success()); 388 } 389 results.assign(1, nullptr); 390 return emitDefaultSilenceableFailure(target); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // InterchangeOp 395 //===----------------------------------------------------------------------===// 396 397 DiagnosedSilenceableFailure 398 transform::InterchangeOp::applyToOne(linalg::GenericOp target, 399 SmallVectorImpl<Operation *> &results, 400 transform::TransformState &state) { 401 SmallVector<unsigned> interchangeVector = 402 extractUIntArray(getIteratorInterchange()); 403 // Exit early if no transformation is needed. 404 if (interchangeVector.empty()) { 405 results.push_back(target); 406 return DiagnosedSilenceableFailure(success()); 407 } 408 SimpleRewriter rewriter(target->getContext()); 409 FailureOr<GenericOp> res = 410 interchangeGenericOp(rewriter, target, interchangeVector); 411 if (failed(res)) 412 return DiagnosedSilenceableFailure::definiteFailure(); 413 results.push_back(res->getOperation()); 414 return DiagnosedSilenceableFailure(success()); 415 } 416 417 LogicalResult transform::InterchangeOp::verify() { 418 SmallVector<unsigned> permutation = 419 extractUIntArray(getIteratorInterchange()); 420 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 421 if (!std::is_permutation(sequence.begin(), sequence.end(), 422 permutation.begin(), permutation.end())) { 423 return emitOpError() 424 << "expects iterator_interchange to be a permutation, found " 425 << getIteratorInterchange(); 426 } 427 return success(); 428 } 429 430 //===---------------------------------------------------------------------===// 431 // MatchOp 432 //===---------------------------------------------------------------------===// 433 434 DiagnosedSilenceableFailure 435 transform::MatchOp::apply(transform::TransformResults &results, 436 transform::TransformState &state) { 437 llvm::StringSet<> strs; 438 if (getOps().has_value()) 439 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(), 440 getOps()->getAsValueRange<StringAttr>().end()); 441 442 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 443 if (payloadOps.size() != 1) 444 return DiagnosedSilenceableFailure( 445 this->emitOpError("requires exactly one target handle")); 446 447 SmallVector<Operation *> res; 448 auto matchFun = [&](Operation *op) { 449 if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) 450 return WalkResult::advance(); 451 452 // Interfaces cannot be matched by name, just by ID. 453 // So we specifically encode the interfaces we care about for this op. 454 if (getInterface().has_value()) { 455 auto iface = getInterface().value(); 456 if (iface == transform::MatchInterfaceEnum::LinalgOp && 457 !isa<linalg::LinalgOp>(op)) 458 return WalkResult::advance(); 459 if (iface == transform::MatchInterfaceEnum::TilingInterface && 460 isa<TilingInterface>(op)) 461 return WalkResult::advance(); 462 } 463 464 if (getAttribute().has_value() && !op->hasAttr(getAttribute().value())) 465 return WalkResult::advance(); 466 467 // All constraints are satisfied. 468 res.push_back(op); 469 return WalkResult::advance(); 470 }; 471 472 payloadOps.front()->walk(matchFun); 473 results.set(getResult().cast<OpResult>(), res); 474 return DiagnosedSilenceableFailure(success()); 475 } 476 477 //===---------------------------------------------------------------------===// 478 // MultiTileSizesOp 479 //===---------------------------------------------------------------------===// 480 481 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( 482 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) { 483 OpBuilder builder(target.getContext()); 484 builder.setInsertionPoint(target); 485 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); 486 OpFoldResult divisor = builder.getIndexAttr(getDivisor()); 487 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( 488 builder, target, getDimension(), targetSize, divisor); 489 if (failed(spec)) { 490 return emitSilenceableError() << "could not generate tile size computation"; 491 } 492 493 AffineExpr s0 = builder.getAffineSymbolExpr(0); 494 AffineExpr s1 = builder.getAffineSymbolExpr(1); 495 Operation *splitPoint = 496 makeComposedAffineApply(builder, target.getLoc(), s0 * s1, 497 {spec->lowTileSize, spec->lowTripCount}); 498 Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); 499 Operation *highTileSize = spec->highTileSize.getDefiningOp(); 500 assert(lowTileSize && highTileSize && splitPoint && 501 "tile sizes are not produced by operations"); 502 results.reserve(results.size() + 3); 503 results.push_back(lowTileSize); 504 results.push_back(highTileSize); 505 results.push_back(splitPoint); 506 return DiagnosedSilenceableFailure::success(); 507 } 508 509 void transform::MultiTileSizesOp::getEffects( 510 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 511 onlyReadsHandle(getTarget(), effects); 512 producesHandle(getResults(), effects); 513 modifiesPayload(effects); 514 } 515 516 //===---------------------------------------------------------------------===// 517 // PadOp 518 //===---------------------------------------------------------------------===// 519 520 DiagnosedSilenceableFailure 521 transform::PadOp::applyToOne(linalg::LinalgOp target, 522 SmallVectorImpl<Operation *> &results, 523 transform::TransformState &state) { 524 // Convert the integer packing flags to booleans. 525 SmallVector<bool> packPaddings; 526 for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) 527 packPaddings.push_back(static_cast<bool>(packPadding)); 528 529 // Convert the padding values to attributes. 530 SmallVector<Attribute> paddingValues; 531 for (auto const &it : 532 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 533 Attribute attr = std::get<0>(it); 534 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 535 // Try to parse string attributes to obtain an attribute of element type. 536 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 537 paddingValues.push_back( 538 parseAttribute(attr.cast<StringAttr>(), elementType)); 539 if (!paddingValues.back()) { 540 auto diag = this->emitOpError("expects a padding that parses to ") 541 << elementType << ", got " << std::get<0>(it); 542 diag.attachNote(target.getLoc()) << "when applied to this op"; 543 return DiagnosedSilenceableFailure::definiteFailure(); 544 } 545 continue; 546 } 547 // Otherwise, add the attribute directly. 548 if (attr.getType() != elementType) { 549 auto diag = this->emitOpError("expects a padding value of type ") 550 << elementType << ", got " << attr; 551 diag.attachNote(target.getLoc()) << "when applied to this op"; 552 return DiagnosedSilenceableFailure::definiteFailure(); 553 } 554 paddingValues.push_back(attr); 555 } 556 557 // Extract the transpose vectors. 558 SmallVector<SmallVector<int64_t>> transposePaddings; 559 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 560 transposePaddings.push_back( 561 extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>())); 562 563 LinalgPaddingOptions paddingOptions; 564 paddingOptions.setPaddingValues(paddingValues); 565 paddingOptions.setPaddingDimensions( 566 extractFromI64ArrayAttr(getPaddingDimensions())); 567 paddingOptions.setPackPaddings(packPaddings); 568 paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); 569 paddingOptions.setTransposePaddings(transposePaddings); 570 571 FailureOr<LinalgOp> result = 572 tryApply<LinalgPaddingPattern>(target, paddingOptions); 573 if (succeeded(result)) { 574 results.push_back(result->getOperation()); 575 return DiagnosedSilenceableFailure(success()); 576 } 577 578 results.assign(1, nullptr); 579 return emitDefaultSilenceableFailure(target); 580 } 581 582 LogicalResult transform::PadOp::verify() { 583 SmallVector<int64_t> packPaddings = 584 extractFromI64ArrayAttr(getPackPaddings()); 585 if (any_of(packPaddings, [](int64_t packPadding) { 586 return packPadding != 0 && packPadding != 1; 587 })) { 588 return emitOpError() 589 << "expects pack_paddings to contain booleans (0/1), found " 590 << getPackPaddings(); 591 } 592 593 SmallVector<int64_t> paddingDimensions = 594 extractFromI64ArrayAttr(getPaddingDimensions()); 595 if (any_of(paddingDimensions, 596 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 597 return emitOpError() 598 << "expects padding_dimensions to contain positive integers, found " 599 << getPaddingDimensions(); 600 } 601 602 SmallVector<int64_t> hoistPaddings = 603 extractFromI64ArrayAttr(getHoistPaddings()); 604 if (any_of(hoistPaddings, 605 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 606 return emitOpError() 607 << "expects hoist_paddings to contain positive integers, found " 608 << getHoistPaddings(); 609 } 610 611 ArrayAttr transposes = getTransposePaddings(); 612 for (Attribute attr : transposes) { 613 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 615 if (!std::is_permutation(sequence.begin(), sequence.end(), 616 transpose.begin(), transpose.end())) { 617 return emitOpError() 618 << "expects transpose_paddings to be a permutation, found " 619 << attr; 620 } 621 } 622 return success(); 623 } 624 625 //===----------------------------------------------------------------------===// 626 // PromoteOp 627 //===----------------------------------------------------------------------===// 628 629 DiagnosedSilenceableFailure 630 transform::PromoteOp::applyToOne(linalg::LinalgOp target, 631 SmallVectorImpl<Operation *> &results, 632 transform::TransformState &state) { 633 LinalgPromotionOptions promotionOptions; 634 if (!getOperandsToPromote().empty()) 635 promotionOptions = promotionOptions.setOperandsToPromote( 636 extractFromI64ArrayAttr(getOperandsToPromote())); 637 if (getUseFullTilesByDefault()) 638 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( 639 getUseFullTilesByDefault()); 640 if (getUseAlloca()) 641 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); 642 if (!getUseFullTileBuffers().empty()) 643 promotionOptions = promotionOptions.setUseFullTileBuffers( 644 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); 645 if (getAlignment().has_value()) 646 promotionOptions = promotionOptions.setAlignment(*getAlignment()); 647 648 if (failed(promoteSubviewsPrecondition(target, promotionOptions))) 649 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 650 651 SimpleRewriter rewriter(target->getContext()); 652 rewriter.setInsertionPoint(target); 653 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); 654 if (failed(res)) 655 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 656 results.push_back(target); 657 return DiagnosedSilenceableFailure(success()); 658 } 659 660 //===----------------------------------------------------------------------===// 661 // ScalarizeOp 662 //===----------------------------------------------------------------------===// 663 664 DiagnosedSilenceableFailure 665 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, 666 SmallVectorImpl<Operation *> &results, 667 transform::TransformState &state) { 668 LinalgTilingOptions tilingOptions; 669 tilingOptions.scalarizeDynamicDims(); 670 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 671 // sizes and asserts that it is not already set. 672 SmallVector<int64_t> emptyTileSizes; 673 LinalgTilingPattern pattern(getContext(), tilingOptions); 674 SimpleRewriter rewriter(getContext()); 675 rewriter.setInsertionPoint(target); 676 FailureOr<TiledLinalgOp> result = 677 pattern.returningMatchAndRewrite(target, rewriter); 678 if (failed(result)) 679 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 680 681 results.push_back(result->op); 682 return DiagnosedSilenceableFailure(success()); 683 } 684 685 //===----------------------------------------------------------------------===// 686 // SplitOp 687 //===----------------------------------------------------------------------===// 688 689 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, 690 TransformState &state) { 691 // Collect the dynamic split points if provided. 692 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); 693 SimpleRewriter rewriter(getContext()); 694 SmallVector<OpFoldResult> splitPoints; 695 splitPoints.reserve(payload.size()); 696 if (getDynamicSplitPoint()) { 697 auto diag = DiagnosedSilenceableFailure::success(); 698 splitPoints = llvm::to_vector(llvm::map_range( 699 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { 700 if (op->getNumResults() != 1 || 701 !op->getResult(0).getType().isIndex()) { 702 diag = emitSilenceableError() 703 << "expected dynamic split point handle to point to a " 704 "single-result index-typed op"; 705 diag.attachNote(op->getLoc()) << "dynamic split point"; 706 } 707 return OpFoldResult(op->getResult(0)); 708 })); 709 if (!diag.succeeded()) 710 return diag; 711 712 if (splitPoints.size() != payload.size()) { 713 emitError() << "expected the dynamic split point handle to point to as " 714 "many operations (" 715 << splitPoints.size() << ") as the target handle (" 716 << payload.size() << ")"; 717 return DiagnosedSilenceableFailure::definiteFailure(); 718 } 719 } else { 720 splitPoints.resize(payload.size(), 721 rewriter.getIndexAttr(getStaticSplitPoint())); 722 } 723 724 // Split each target operation. 725 SmallVector<Operation *> first, second; 726 for (const auto &pair : llvm::zip(payload, splitPoints)) { 727 Operation *target = std::get<0>(pair); 728 auto linalgOp = dyn_cast<LinalgOp>(target); 729 if (!linalgOp) { 730 auto diag = emitSilenceableError() << "only applies to structured ops"; 731 diag.attachNote(target->getLoc()) << "target op"; 732 return diag; 733 } 734 735 if (getDimension() >= linalgOp.getNumLoops()) { 736 auto diag = emitSilenceableError() << "dimension " << getDimension() 737 << " does not exist in target op"; 738 diag.attachNote(target->getLoc()) << "target op"; 739 return diag; 740 } 741 742 rewriter.setInsertionPoint(linalgOp); 743 std::tie(first.emplace_back(), second.emplace_back()) = 744 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); 745 } 746 747 results.set(getFirst().cast<OpResult>(), first); 748 results.set(getSecond().cast<OpResult>(), second); 749 return DiagnosedSilenceableFailure::success(); 750 } 751 752 void SplitOp::getEffects( 753 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 754 consumesHandle(getTarget(), effects); 755 if (getDynamicSplitPoint()) 756 onlyReadsHandle(getDynamicSplitPoint(), effects); 757 producesHandle(getResults(), effects); 758 modifiesPayload(effects); 759 } 760 761 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 762 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; 763 IntegerAttr staticSplitPoint; 764 auto pdlOperationType = 765 pdl::OperationType::get(parser.getBuilder().getContext()); 766 if (parser.parseOperand(target) || 767 parser.resolveOperand(target, pdlOperationType, result.operands) || 768 parser.parseKeyword("after")) 769 return failure(); 770 771 OptionalParseResult dynamicPointParseResult = 772 parser.parseOptionalOperand(dynamicSplitPoint); 773 if (!dynamicPointParseResult.hasValue()) { 774 int64_t staticSplitPointValue; 775 if (failed(parser.parseInteger(staticSplitPointValue))) 776 return failure(); 777 778 staticSplitPoint = 779 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); 780 } else { 781 if (failed(*dynamicPointParseResult) || 782 parser.resolveOperand(dynamicSplitPoint, pdlOperationType, 783 result.operands)) { 784 return failure(); 785 } 786 787 staticSplitPoint = 788 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize); 789 } 790 791 result.addAttribute( 792 SplitOp::getStaticSplitPointAttrName(result.name).getValue(), 793 staticSplitPoint); 794 if (failed(parser.parseOptionalAttrDict(result.attributes))) 795 return failure(); 796 797 result.addTypes({pdlOperationType, pdlOperationType}); 798 return success(); 799 } 800 801 void SplitOp::print(OpAsmPrinter &printer) { 802 printer << " " << getTarget() << " after "; 803 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); 804 if (staticSplitSize != ShapedType::kDynamicSize) 805 printer << staticSplitSize; 806 else 807 printer << getDynamicSplitPoint(); 808 printer << " "; 809 printer.printOptionalAttrDict(getOperation()->getAttrs(), 810 {getStaticSplitPointAttrName()}); 811 } 812 813 LogicalResult SplitOp::verify() { 814 if ((static_cast<int64_t>(getStaticSplitPoint()) != 815 ShapedType::kDynamicSize) ^ 816 (getDynamicSplitPoint() == nullptr)) { 817 return emitOpError() 818 << "expects either a dynamic or a static split point to be provided"; 819 } 820 return success(); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // SplitReductionOp 825 //===----------------------------------------------------------------------===// 826 827 DiagnosedSilenceableFailure 828 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, 829 SmallVectorImpl<Operation *> &results, 830 transform::TransformState &state) { 831 ControlSplitReductionFn splitFn = [&](LinalgOp) { 832 return std::pair<int64_t, unsigned>(getSplitFactor(), 833 getInsertSplitDimension()); 834 }; 835 SimpleRewriter rewriter(getContext()); 836 rewriter.setInsertionPoint(target); 837 FailureOr<SplitReductionResult> splitResult = 838 (getUseScalingAlgorithm()) 839 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 840 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 841 if (failed(splitResult)) 842 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 843 844 results.push_back(splitResult->initOrAlloc); 845 results.push_back(splitResult->fillOp); 846 results.push_back(splitResult->splitLinalgOp); 847 results.push_back(splitResult->resultCombiningLinalgOp); 848 return DiagnosedSilenceableFailure(success()); 849 } 850 851 //===----------------------------------------------------------------------===// 852 // TileOp 853 //===----------------------------------------------------------------------===// 854 855 DiagnosedSilenceableFailure 856 transform::TileOp::apply(TransformResults &transformResults, 857 TransformState &state) { 858 LinalgTilingOptions tilingOptions; 859 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 860 861 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); 862 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; 863 dynamicSizeProducers.reserve(getDynamicSizes().size()); 864 for (Value dynamicSizeProducerHandle : getDynamicSizes()) { 865 dynamicSizeProducers.push_back( 866 state.getPayloadOps(dynamicSizeProducerHandle)); 867 868 if (dynamicSizeProducers.back().size() != targets.size()) { 869 DiagnosedSilenceableFailure diag = 870 emitSilenceableError() 871 << "expected as many dynamic size-producing operations (" 872 << dynamicSizeProducers.back().size() << ") as target ops (" 873 << targets.size() << ")"; 874 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 875 return diag; 876 } 877 878 for (Operation *op : dynamicSizeProducers.back()) { 879 if (op->getNumResults() == 1 && 880 op->getResult(0).getType().isa<IndexType>()) 881 continue; 882 DiagnosedSilenceableFailure diag = 883 emitSilenceableError() << "expected sizes to be produced by ops " 884 "with a single index-type result"; 885 diag.attachNote(op->getLoc()) << "size producer op"; 886 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; 887 return diag; 888 } 889 } 890 891 SmallVector<Operation *> tiled; 892 SmallVector<SmallVector<Operation *, 4>, 4> loops; 893 loops.resize(getLoops().size()); 894 for (auto &en : llvm::enumerate(targets)) { 895 auto linalgOp = dyn_cast<LinalgOp>(en.value()); 896 if (!linalgOp) { 897 DiagnosedSilenceableFailure diag = emitSilenceableError() 898 << "only linalg ops are supported"; 899 diag.attachNote(en.value()->getLoc()) << "target op"; 900 return diag; 901 } 902 903 unsigned index = en.index(); 904 if (!tileSizes.empty()) { 905 tilingOptions.setTileSizeComputationFunction( 906 [&, index](OpBuilder &b, Operation *) { 907 SmallVector<Value, 4> sizes; 908 sizes.reserve(tileSizes.size()); 909 unsigned dynamicIdx = 0; 910 for (OpFoldResult ofr : getMixedSizes()) { 911 if (auto attr = ofr.dyn_cast<Attribute>()) { 912 sizes.push_back(b.create<arith::ConstantIndexOp>( 913 getLoc(), attr.cast<IntegerAttr>().getInt())); 914 } else { 915 sizes.push_back( 916 dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); 917 } 918 } 919 return sizes; 920 }); 921 } 922 923 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 924 LinalgTilingPattern pattern(getContext(), tilingOptions); 925 SimpleRewriter rewriter(linalgOp.getContext()); 926 FailureOr<TiledLinalgOp> tiledOp = 927 pattern.returningMatchAndRewrite(linalgOp, rewriter); 928 if (failed(tiledOp)) 929 return DiagnosedSilenceableFailure::definiteFailure(); 930 931 tiled.push_back(tiledOp->op); 932 for (const auto &en2 : llvm::enumerate(tiledOp->loops)) 933 loops[en2.index()].push_back(en2.value()); 934 } 935 936 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); 937 for (const auto &en : llvm::enumerate(loops)) 938 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); 939 940 return DiagnosedSilenceableFailure::success(); 941 } 942 943 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() { 944 ValueRange dynamic = getDynamicSizes(); 945 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); 946 SmallVector<OpFoldResult> results; 947 results.reserve(tileSizes.size()); 948 unsigned dynamicPos = 0; 949 Builder builder(getContext()); 950 for (int64_t size : tileSizes) { 951 if (size == ShapedType::kDynamicSize) { 952 results.push_back(dynamic[dynamicPos++]); 953 } else { 954 results.push_back(builder.getIndexAttr(size)); 955 } 956 } 957 return results; 958 } 959 960 ParseResult transform::TileOp::parse(OpAsmParser &parser, 961 OperationState &result) { 962 OpAsmParser::UnresolvedOperand target; 963 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; 964 ArrayAttr staticSizes; 965 auto pdlOperationType = pdl::OperationType::get(parser.getContext()); 966 if (parser.parseOperand(target) || 967 parser.resolveOperand(target, pdlOperationType, result.operands) || 968 parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || 969 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || 970 parser.parseOptionalAttrDict(result.attributes)) 971 return ParseResult::failure(); 972 973 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); 974 size_t numExpectedLoops = 975 staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); 976 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); 977 return success(); 978 } 979 980 void TileOp::print(OpAsmPrinter &p) { 981 p << ' ' << getTarget(); 982 printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), 983 getStaticSizes()); 984 p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); 985 } 986 987 void transform::TileOp::getEffects( 988 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 989 consumesHandle(getTarget(), effects); 990 onlyReadsHandle(getDynamicSizes(), effects); 991 producesHandle(getTiledLinalgOp(), effects); 992 producesHandle(getLoops(), effects); 993 modifiesPayload(effects); 994 } 995 996 //===----------------------------------------------------------------------===// 997 // TileToForeachThreadOp 998 //===----------------------------------------------------------------------===// 999 1000 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( 1001 TilingInterface target, SmallVectorImpl<Operation *> &results, 1002 transform::TransformState &state) { 1003 IRRewriter rewriter(getContext()); 1004 rewriter.setInsertionPoint(target); 1005 auto maybeThreadDimMappingAttr = getThreadDimMapping(); 1006 auto dimMapping = 1007 llvm::to_vector(maybeThreadDimMappingAttr 1008 ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) 1009 : ArrayRef<int64_t>{}); 1010 1011 FailureOr<ForeachThreadTilingResult> tilingResult = failure(); 1012 if (Optional<ArrayAttr> numThreads = getNumThreads()) 1013 tilingResult = linalg::tileToForeachThreadOp( 1014 rewriter, target, getAsOpFoldResult(*numThreads), dimMapping); 1015 1016 if (Optional<ArrayAttr> tileSizes = getTileSizes()) 1017 tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( 1018 rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping); 1019 1020 if (failed(tilingResult)) 1021 return emitDefaultSilenceableFailure(target); 1022 rewriter.replaceOp(target, tilingResult->tileOp->getResults()); 1023 results.assign({tilingResult->tileOp, tilingResult->tiledOp}); 1024 return DiagnosedSilenceableFailure(success()); 1025 } 1026 1027 //===----------------------------------------------------------------------===// 1028 // VectorizeOp 1029 //===----------------------------------------------------------------------===// 1030 1031 DiagnosedSilenceableFailure 1032 transform::VectorizeOp::applyToOne(Operation *target, 1033 SmallVectorImpl<Operation *> &results, 1034 transform::TransformState &state) { 1035 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 1036 auto diag = this->emitOpError("requires isolated-from-above targets"); 1037 diag.attachNote(target->getLoc()) << "non-isolated target"; 1038 return DiagnosedSilenceableFailure::definiteFailure(); 1039 } 1040 1041 MLIRContext *ctx = getContext(); 1042 RewritePatternSet patterns(ctx); 1043 patterns.add<LinalgVectorizationPattern>(ctx); 1044 1045 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 1046 vector::populateVectorReductionToContractPatterns(patterns); 1047 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 1048 linalg::LinalgCopyVTWForwardingPattern>(ctx, 1049 /*benefit=*/2); 1050 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 1051 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 1052 if (getVectorizePadding()) 1053 linalg::populatePadOpVectorizationPatterns(patterns); 1054 1055 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 1056 return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); 1057 1058 results.push_back(target); 1059 return DiagnosedSilenceableFailure(success()); 1060 } 1061 1062 //===----------------------------------------------------------------------===// 1063 // Transform op registration 1064 //===----------------------------------------------------------------------===// 1065 1066 namespace { 1067 /// Registers new ops and declares PDL as dependent dialect since the additional 1068 /// ops are using PDL types for operands and results. 1069 class LinalgTransformDialectExtension 1070 : public transform::TransformDialectExtension< 1071 LinalgTransformDialectExtension> { 1072 public: 1073 using Base::Base; 1074 1075 void init() { 1076 declareDependentDialect<pdl::PDLDialect>(); 1077 1078 declareGeneratedDialect<AffineDialect>(); 1079 declareGeneratedDialect<arith::ArithmeticDialect>(); 1080 declareGeneratedDialect<scf::SCFDialect>(); 1081 declareGeneratedDialect<vector::VectorDialect>(); 1082 1083 registerTransformOps< 1084 #define GET_OP_LIST 1085 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 1086 >(); 1087 } 1088 }; 1089 } // namespace 1090 1091 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" 1092 1093 #define GET_OP_CLASSES 1094 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 1095 1096 void mlir::linalg::registerTransformDialectExtension( 1097 DialectRegistry ®istry) { 1098 registry.addExtensions<LinalgTransformDialectExtension>(); 1099 } 1100