//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
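///
/// For illustration (a hypothetical input, not taken from a test): with a 3-d
/// iteration space and `unitDims = {1}`, the dim replacements are
/// [d0, 0, d1], so
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>  becomes
///   affine_map<(d0, d1) -> (d0, 0, d1)>,      and
///   affine_map<(d0, d1, d2) -> (d1, d2)>      becomes
///   affine_map<(d0, d1) -> (0, d1)>.
/// The number of results per map is unchanged; only uses of the dropped
/// iteration dims are replaced by the constant 0.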
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
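    // For example (a hypothetical operand, for illustration only): an operand
    // of type memref<1x5xf32> accessed via affine_map<(d0, d1) -> (d0, d1)>
    // forces the d0 loop to have a single iteration.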
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (const auto &expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (const auto &attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op, and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with a non-identity layout map; these are always
  // left unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getLayout().isIdentity())
      return llvm::None;
  }

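  // For illustration (hypothetical shapes, not taken from a test): an operand
  // of type tensor<1x5x1x4xf32> accessed via
  // affine_map<(d0, d1, d2) -> (0, d1, 0, d2)> folds to tensor<5x4xf32>
  // accessed via affine_map<(d0, d1, d2) -> (d1, d2)>, with reassociation
  // groups [0, 1, 2] and [3]: leading unit dims are folded forward into the
  // first kept dim, and every other unit dim is folded back into the kept dim
  // that precedes it.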
  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if the type is unsupported.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if the type is unsupported.
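  // Note: maybeCollapse is applied to operands (original type to rank-reduced
  // type) and maybeExpand to results (rank-reduced type back to the original
  // type). As a hypothetical illustration, collapsing a tensor<1x5xf32>
  // operand to tensor<5xf32> uses the reassociation [[0, 1]].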
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type and indexing map, and create a reassociation that maps every
        // dimension to itself (an identity mapping).
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount = */ 0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changed, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and
    // contiguity.
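    // For example (hypothetical operand): an input of type tensor<1x5xf32>
    // whose rank-reduced type is tensor<5xf32> gets a tensor.collapse_shape
    // inserted before the new generic op; operands whose type is unchanged are
    // passed through as-is. flattenedIdx walks newInputOutputTypes and
    // reassociationMaps in operand order, inputs first and then outputs.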
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // Compute the result types of the new op from the rank-reduced output
    // types computed above.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType =
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides)
            .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
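///
/// For illustration (a hypothetical op, not taken from a test): inserting a
/// tensor<1x4xf32> source with sizes [1, 4] first collapses the source to
/// tensor<4xf32> (reassociation [[0, 1]]) and then re-creates the insertion
/// with the original offsets/sizes/strides, i.e. as a rank-reducing
/// insert_slice.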
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
      // that the insertion point is just before the ParallelCombiningOp in the
      // parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}