//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker and matchDisjunction is empty: anything matches.
    if (matchDisjunction.empty())
      return success();

    // 2. The op has no marker but one was expected: fail and notify.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
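
// Illustration of the filter protocol above (hypothetical marker values, not
// taken from a real pipeline): a filter built from the identifiers
// "tile" -> "tiled" only matches ops carrying the marker, e.g.
//   linalg.matmul {__internal_linalg_transform__ = "tile"} ...
// and, once the rewrite succeeds, replaceLinalgTransformationFilter below
// re-labels the result with __internal_linalg_transform__ = "tiled" so that a
// later staged pattern can match it.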
void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `opOperand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is the result of a
///      freshly created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgTilingOptions &options, Value &result) {
  // Already statically shaped, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // The smallest bounding index must exist for all sizes.
    // For now, return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
  return success();
}
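
// Sketch of the bounding-box padding above (hypothetical IR and bounds): if
// the operand is produced by
//   %s = tensor.extract_slice %t[%i, %j] [%sz0, %sz1] [1, 1]
//       : tensor<?x?xf32> to tensor<?x?xf32>
// and getSmallestBoundingIndex proves %sz0 <= 4 and %sz1 <= 8, the operand is
// replaced by a PadTensorOp that pads high to the static tensor<4x8xf32> type.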
// Try to create a static bounding box around each operand of `res.op`.
// If successful, `res.op` is rewritten in static form with padded operands.
// `res.op` is updated to the cloned static form of the op on success.
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");
  if (!opToPad.hasDynamicShape())
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, options, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  // This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}
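
// Sketch of the result-recovery step above (hypothetical IR): for a padded op
// result %p : tensor<4x8xf32> replacing an original dynamic result %r of type
// tensor<?x?xf32>, the original value is recovered as
//   %d0 = tensor.dim %r, %c0 : tensor<?x?xf32>
//   %d1 = tensor.dim %r, %c1 : tensor<?x?xf32>
//   %s  = tensor.extract_slice %p[0, 0] [%d0, %d1] [1, 1]
//       : tensor<4x8xf32> to tensor<?x?xf32>
// The tensor.dim ops keep the original op live until they fold away, which is
// why they are excluded from the replaceOpWithIf above.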
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Set up an RAII guard to return properly.
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to the derived pattern.
    result = *res;
    // Replace the filter on both tiledOp and tiledAndPaddedOp, if necessary.
    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      filter.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Returning failure signals callers to discard the TiledLinalgOp that the
    // RAII guard propagates to `result`.
    return failure();
  }

  // Do not perform replacement of `linalgOp`; let the derived patterns
  // do this as they see fit, from the resulting TiledLinalgOp.
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}
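
// For illustration (hypothetical IR): when tiling produces loops, e.g.
//   %r = scf.for %iv = %c0 to %ub step %c16 iter_args(%a = %out) -> ... {
//     %t = linalg.matmul ...  // TiledLinalgOp::op
//     scf.yield ...
//   }
// getTiledOpResult returns the results of the outermost loop (%r); when no
// loops were created it returns the results of the tiled op itself.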
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op is always
    // an OpOperand. We could assert, but simply continue if this is not the
    // case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the loops only if there is a non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
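
// Worked example of the unfused-loop tile-size computation above
// (hypothetical sizes): with tileSizes = {16, 32, 64} and
// fusedLoopDims = {0}, loop 0 is already tiled by 16 during fusion, so its
// size is reset to zero and only the remaining loops are tiled again:
// unfusedLoopTileSizes = {0, 32, 64}.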
/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
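
// Example (hypothetical vector): applying the pattern above with
// interchangeVector = {1, 2, 0} to a 3-loop linalg.generic permutes its
// iterator order and adjusts the indexing maps accordingly, so the loop nest
// executes in the permuted order while computing the same result.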
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern is creating other
  // ops, so if the promotion fails, those need to be cleaned up, which
  // doesn't seem to be happening here. So to fail properly, we should be
  // cloning the op and deleting the previous op. This needs more
  // investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
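
// Typical usage of applyStagedPatterns (hypothetical pattern sets and
// legalization callback, shown for illustration only):
//   SmallVector<FrozenRewritePatternSet, 2> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   if (failed(applyStagedPatterns(
//           funcOp, stage1, canonicalizationPatterns,
//           [](Operation *op) { return legalizeFunc(op); })))
//     signalPassFailure();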
/// Traverse the `dims` and substitute known min or max expressions returned by
/// the lambda `getMinMaxExpr`.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols,
                            GetMinMaxExprFn getMinMaxExpr) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        auto minMax = getMinMaxExpr(dim, dims, symbols);
        if (!minMax)
          continue;
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
        // Substitute occurrences of `dimExpr` by either the min expression or
        // the max expression depending on whether the value is used with a
        // positive or negative coefficient.
        AffineExpr substitutedExpr =
            substWithMin(expr, dimExpr, minMax->first, minMax->second);
        LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Clean up and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
                              exprs.front().getContext());

    LLVM_DEBUG({
      DBGS() << "Map to simplify: " << map << "\n";
      DBGS() << "Operands:\n";
      for (Value v : operands)
        DBGS() << v << "\n";
    });

    // Pull in affine.apply operations and compose them fully into the
    // result.
    fullyComposeAffineMapAndOperands(&map, &operands);
    canonicalizeMapAndOperands(&map, &operands);
    map = simplifyAffineMap(map);
    // Assign the results.
    exprs.assign(map.getResults().begin(), map.getResults().end());
    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
    symbols.assign(operands.begin() + map.getNumDims(), operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}
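
// Example of the substitution rule above (hypothetical range): for
// expr = d0 - d1, if getMinMaxExpr reports d1 in [0, s0], then d1 occurs with
// a negative coefficient and is replaced by its max, yielding d0 - s0; a dim
// occurring with a positive coefficient would instead be replaced by its min.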
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
/// dimensions with known range by new expressions involving the min or max
/// expression:
///   - If the AffineDimExpr mapped to a known value has a positive sign, it
///     is replaced by the min expression.
///   - If the AffineDimExpr mapped to a known value has a negative sign, it
///     is replaced by the max expression.
/// All known values are iteratively replaced.
/// This is used as an intermediate step in computing bounding boxes and
/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to
/// have positive values (positive orthant assumption).
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
AffineMapAndOperands
mlir::linalg::substituteMin(AffineMinOp affineMinOp,
                            GetMinMaxExprFn getMinMaxExpr) {
  AffineMapAndOperands res{affineMinOp.getAffineMap(),
                           SmallVector<Value>(affineMinOp.getDimOperands()),
                           SmallVector<Value>(affineMinOp.getSymbolOperands())};
  res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
                       getMinMaxExpr);
  return res;
}

LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
  AffineMap map = affineMapAndOperands.map;

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    auto isNonPositive = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check that every result is statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNonPositive))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value> resultOperands = affineMapAndOperands.dims;
      llvm::append_range(resultOperands, affineMapAndOperands.symbols);
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}
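
// Worked example of the canonicalization above (hypothetical range): for
//   %m = affine.min affine_map<(d0) -> (d0, 100)>(%iv)
// where getMinMaxFn reports %iv in [128, 256], substituting the lower bound
// for the positively-signed d0 gives the map (128, 100). For the candidate
// e = 100 the differences (28, 0) are all non-negative constants, so 100 is
// the static minimum and the affine.min folds to a constant 100.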
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
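
// Sketch of the lowering above (hypothetical IR): padding a tensor<2x3xf32>
// to tensor<4x6xf32> with low padding (1, 2) becomes, roughly:
//   %init = linalg.init_tensor [4, 6] : tensor<4x6xf32>
//   %fill = linalg.fill(%pad, %init) ...
//   %res  = linalg.generic {
//             indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                              affine_map<(d0, d1) -> (d0 + 1, d1 + 2)>], ...}
//             ins(%src ...) outs(%fill ...)
// i.e. the generic iterates over the source and writes each element at its
// offset position in the filled tensor.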
/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead to
  // copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

/// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
/// it must be an IntegerAttr.
static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
  if (auto val = ofr.dyn_cast<Value>())
    return val;
  auto intVal = getConstantIntValue(ofr);
  assert(intVal && "expected Value or IntegerAttr");
  return builder.create<ConstantIndexOp>(loc, *intVal);
}

/// Given a value, try to extract a constant index-typed integer as an
/// Attribute. If this fails, return the original value.
static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
  if (auto constInt = getConstantIntValue(val))
    return builder.getIndexAttr(*constInt);
  return val;
}
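
// Sketch of GeneralizePadTensorOpPattern above (hypothetical IR): a pad of
// %src : tensor<?x4xf32> with constant value %cst, low padding (%l, 1) and
// high padding (%h, 1) becomes, roughly:
//   %d0   = tensor.dim %src, %c0
//   %sz0  = addi (addi %d0, %l), %h
//   %init = linalg.init_tensor [%sz0, 6] : tensor<?x6xf32>
//   %fill = linalg.fill(%cst, %init) ...
//   %res  = tensor.insert_slice %src into %fill[%l, 1] [%d0, 4] [1, 1]
// unless optimizeCopyFn rewrites the copy in a more specialized way first.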
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = sliceOp.getLoc();
  AffineExpr dim0, dim1;
  bindDims(rewriter.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, addMap,
                                                ValueRange{v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineApplyOp>(loc, subMap,
                                                ValueRange{v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
  auto min = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](Value v1, Value v2) {
    return rewriter.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Zero index-typed integer.
  auto zero = rewriter.create<ConstantIndexOp>(loc, 0);

  // Helper function for filling static/dynamic low/high padding indices
  // vectors of PadTensorOp.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamicSize);
      dynIndices.push_back(val);
    }
  };

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<Value> newLows, newHighs;
  SmallVector<int64_t> staticNewLows, staticNewHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
    auto offset = asValue(rewriter, loc, sliceOp.getMixedOffsets()[dim]);
    auto length = asValue(rewriter, loc, sliceOp.getMixedSizes()[dim]);
    auto srcSize =
        rewriter.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    Value newLow = max(zero, sub(low, offset));
    appendIndex(newLow, newLows, staticNewLows);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
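
    // Worked example for the offset/length arithmetic here and below
    // (hypothetical values): with low = 2, offset = 1, length = 4 and
    // srcSize = 10:
    //   newLow    = max(0, 2 - 1)              = 1
    //   newOffset = min(max(1 - 2, 0), 10)     = 0
    //   endLoc    = min(max(1 - 2 + 4, 0), 10) = 3
    //   newLength = 3 - 0                      = 3
    //   newHigh   = (4 - 3) - 1                = 0
    // i.e. the new slice reads src[0..2] and pads one element low.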
    Value newOffset = min(max(sub(offset, low), zero), srcSize);
    newOffsets.push_back(asOpFoldResult(rewriter, newOffset));

    // The original ExtractSliceOp was reading until position
    // `offset + length`. Therefore, the corresponding position within the
    // source tensor is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
    Value newLength = sub(endLoc, newOffset);
    newLengths.push_back(asOpFoldResult(rewriter, newLength));

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      Value check = rewriter.create<CmpIOp>(
          loc, CmpIPredicate::eq, newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    Value newHigh = sub(sub(length, newLength), newLow);
    appendIndex(newHigh, newHighs, staticNewHighs);

    // Only unit stride supported.
    newStrides.push_back(rewriter.getIndexAttr(1));
  }

  // Insert a cast to ensure that the types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    auto castOp = rewriter.create<tensor::CastOp>(loc, sliceOp.getType(), val);
    return castOp;
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // The shape of the GenerateOp is the same as the existing SliceOp.
    RankedTensorType type = sliceOp.getType();
    SmallVector<Value> dynDims;
    for (unsigned i = 0; i < type.getRank(); ++i) {
      if (type.isDynamicDim(i))
        dynDims.push_back(asValue(rewriter, loc, sliceOp.getMixedOffsets()[i]));
    }

    // Create the GenerateOp.
    auto generateOp = rewriter.create<tensor::GenerateOp>(loc, type, dynDims);

    // Copy the region to the new op.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
    // Rewrite linalg::YieldOp to tensor::YieldOp.
    {
      OpBuilder::InsertionGuard guard(rewriter);
      auto yieldOp = dyn_cast<linalg::YieldOp>(
          generateOp.getRegion().front().getTerminator());
      assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
      assert(yieldOp.values().size() == 1);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<tensor::YieldOp>(
          yieldOp, yieldOp.values()[0]);
    }

    return castResult(generateOp);
  };
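
  // Continuing the worked example above (hypothetical values): if instead
  // offset = 7 with low = 2, length = 4 and srcSize = 4, then
  // newOffset = min(max(7 - 2, 0), 4) = 4 and
  // endLoc = min(max(7 - 2 + 4, 0), 4) = 4, so newLength = 0: the source is
  // never read and the createGenerateOp path is taken (statically via
  // hasZeroLen, or at runtime via dynHasZeroLenCond).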
  // Emit a SliceOp and a PadTensorOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadTensorOfSubTensor = [&]() {
    // Create pad_tensor(subtensor(x)).
    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, padOp.source(), newOffsets, newLengths, newStrides);
    auto newPadTensorOp = rewriter.create<PadTensorOp>(
        loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs);

    // Copy the region to the new PadTensorOp.
    BlockAndValueMapping bvm;
    padOp.region().cloneInto(&newPadTensorOp.getRegion(), bvm);

    // Cast the result and return.
    return castResult(newPadTensorOp);
  };

  // Rewrite subtensor(pad_tensor(x)) into a GenerateOp when it is statically
  // known that the original data source x is not used.
  if (hasZeroLen) {
    rewriter.replaceOp(sliceOp, createGenerateOp());
    return success();
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (dynHasZeroLenCond) {
    auto result = rewriter.create<scf::IfOp>(
        loc, sliceOp.getType(), dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createGenerateOp());
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createPadTensorOfSubTensor());
        });
    rewriter.replaceOp(sliceOp, result.getResult(0));
    return success();
  }

  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, createPadTensorOfSubTensor());
  return success();
}