//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as
// rewrite patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}
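
// Illustrative usage (a sketch, not part of this file): build a filter that
// only fires on ops tagged "tile" and re-tags them "tiled" once rewritten, so
// that greedy pattern application does not recurse. The identifier names are
// arbitrary.
//
//   Identifier match = Identifier::get("tile", ctx);
//   Identifier replacement = Identifier::get("tiled", ctx);
//   LinalgTransformationFilter filter({match}, replacement);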

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no filter and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. The op has no filter but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
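
// Illustrative usage (a sketch): tile the two outermost loops by 8 and 16;
// any remaining loops keep tile size 0, i.e. they are left untiled.
//
//   auto options = LinalgTilingOptions().setTileSizes({8, 16});
//
// `scalarizeDynamicDims()` is the alternative entry point: it derives tile
// sizes from the op itself, tiling every dynamic dimension by 1.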

/// Helper function that tries to pad `opOperand`. Exit early and return
/// success for scalar operands or if `paddingFunc` returns failure. Otherwise,
/// try to pad the operand even if it already has a static shape. Set `result`
/// to the result of the created PadTensorOp or return failure if the operand
/// cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op: cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A smallest bounding index must exist for all sizes; for now, return
    // failure if we cannot find one.
    if (!indexAttr) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    staticSizes.push_back(indexAttr.getInt());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), b);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically,
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
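
// Illustrative effect of `rewriteAsPaddedOp` (a sketch; the exact IR depends
// on the op, the derived bounding boxes, and the value returned by
// `paddingFunc`). An operand defined by
//
//   %s = tensor.extract_slice %t[%i, 0] [%sz, 16] [1, 1]
//       : tensor<?x16xf32> to tensor<?x16xf32>
//
// whose %sz has a smallest static bound of 8 is padded up to that bound:
//
//   %p = linalg.pad_tensor %s low[0, 0] high[%h, 0] {
//   ^bb0(%arg0: index, %arg1: index):
//     linalg.yield %cst : f32
//   } : tensor<?x16xf32> to tensor<8x16xf32>
//
// The op is then cloned on the static shapes and its results are sliced back
// down to the original (reified) sizes.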

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
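
// Illustrative effect of peeling (a sketch): given a tiled loop
//
//   scf.for %iv = %c0 to %ub step %c8 { ... }
//
// where %ub need not be a multiple of 8, peeling splits it into a main loop
// that executes only full iterations and a second loop for the remainder, so
// the main-loop body no longer needs to clamp to the upper bound.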

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
  if (!res)
    return failure();

  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res->op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, res->op, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (succeeded(newResults)) {
    rewriter.replaceOp(res->op, newResults.getValue());
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform the replacement of `linalgOp` here; the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Fail so that the RAII guard does not propagate the TiledLinalgOp to
  // `result`.
  return failure();
}
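
// Illustrative tiling result (a sketch, with assumed tile sizes {8, 16}): a
// linalg.matmul on tensors becomes a loop nest over tiles whose body applies
// linalg.matmul to tensor.extract_slice'd operands; the loop construct
// (scf.for, scf.parallel, affine.for, or linalg.tiled_loop) is selected via
// LinalgTilingOptions::loopType.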

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, the indexing op is always an
    // OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation by the padded results.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular, this
  // should break the named-op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
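
// Illustrative interchange (a sketch): with interchangeVector = {1, 0} on a
// 2-D linalg.generic, the iteration order of the two loops is swapped; the
// indexing maps and iterator types are permuted accordingly so that every
// operand is still accessed at the same coordinates, only in a different
// loop order.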

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}
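
// Illustrative generalization (a sketch):
//
//   %0 = linalg.matmul ins(%A, %B : ...) outs(%C : ...) -> tensor<?x?xf32>
//
// becomes a linalg.generic with the same indexing maps and iterator types
// whose region spells out the scalar multiply-accumulate explicitly.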

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. To fail properly, we should be cloning the op
  // and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
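
// Illustrative usage of `applyStagedPatterns` (a sketch; the pattern sets are
// placeholders): each stage-1 set is applied to convergence, followed each
// time by the stage-2 cleanup set and then by the stage-3 callback.
//
//   SmallVector<FrozenRewritePatternSet> stage1 = /* ... */;
//   FrozenRewritePatternSet stage2 = /* ... */;
//   (void)applyStagedPatterns(funcOp, stage1, stage2,
//                             [](Operation *op) { return success(); });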

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // This uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
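
// Illustrative rewrite (a sketch, static shapes assumed):
//
//   %0 = linalg.pad_tensor %src low[1, 2] high[2, 1] {
//   ^bb0(%i: index, %j: index):
//     linalg.yield %cst : f32
//   } : tensor<5x5xf32> to tensor<8x8xf32>
//
// becomes, roughly:
//
//   %init = linalg.init_tensor [8, 8] : tensor<8x8xf32>
//   %fill = linalg.fill(%cst, %init) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
//   %0 = linalg.generic ... ins(%src) outs(%fill)  // copy shifted by low pad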

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite the linalg::YieldOp to a tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // dimensions is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized: generate a tensor::InsertSliceOp instead
  // to copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // The strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
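
// Illustrative swap (a sketch): instead of slicing the padded tensor,
//
//   %p = linalg.pad_tensor %t ... : tensor<?x?xf32> to tensor<8x16xf32>
//   %s = tensor.extract_slice %p[...] [...] [1, 1]
//
// the pattern pads a slice of %t itself, i.e. produces
// pad_tensor(extract_slice(%t)), which enables further folding when the
// sliced region lies entirely inside the source.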

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
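//
// For example (illustrative, static shapes assumed), a 1x3 filter with a
// size-1 output height collapses from
//
//   linalg.conv_2d_nhwc_hwcf ins(%in, %f : tensor<1x1x10x4xf32>,
//                                          tensor<1x3x4x8xf32>) ...
//
// to the 1-D convolution
//
//   linalg.conv_1d_nwc_wcf ins(%in2, %f2 : tensor<1x10x4xf32>,
//                                          tensor<3x4x8xf32>) ...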

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
    if (linalgOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value filter = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto inputShape = inputType.getShape();
    auto filterShape = filterType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
      return failure();
    bool removeH = ohSize == 1;

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    SmallVector<int64_t, 3> newInputShape{
        inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
    auto newInputType = RankedTensorType::get(
        newInputShape, inputType.getElementType(), inputType.getEncoding());

    SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
                                           filterShape[2], filterShape[3]};
    auto newFilterType = RankedTensorType::get(
        newFilterShape, filterType.getElementType(), filterType.getEncoding());

    SmallVector<int64_t, 3> newOutputShape{
        outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
    auto newOutputType = RankedTensorType::get(
        newOutputShape, outputType.getElementType(), outputType.getEncoding());

    SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
    SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};

    // Reshape all operands for 1-D convolution.
    Location loc = convOp.getLoc();
    Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newInputType, input, ioReshapeIndices);
    Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newFilterType, filter, fReshapeIndices);
    Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newOutputType, output, ioReshapeIndices);

    // We need to shrink the strides and dilations too.
    auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
    auto stridesAttr = rewriter.getI64VectorAttr(stride);
    auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
    auto dilationsAttr = rewriter.getI64VectorAttr(dilation);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newFilter},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
        convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
    return success();
  }
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
                                                      benefit);
}