//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
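///
/// For example (an illustrative sketch): with `unitDims = {1}` for a
/// two-dimensional iteration space and the original indexing maps
///   affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>
/// the unit dimension d1 is replaced by the constant 0 and the surviving
/// dimension is renumbered, producing
///   affine_map<(d0) -> (d0, 0)>, affine_map<(d0) -> (0, d0)>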
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  assert(genericOp->getNumRegions() == 1 &&
         genericOp->getRegion(0).getBlocks().size() == 1 &&
         "expected generic operation to have one block.");
  Block &block = genericOp->getRegion(0).front();

  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
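    // For example (illustrative), an operand of type memref<1x?xf32> accessed
    // through the identity map (d0, d1) -> (d0, d1) makes d0 a unit-trip
    // count loop.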
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in
    // the reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an AffineConstantExpr
/// of value 0. Given the `type` of a result/operand of a Linalg op, and its
/// `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
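///
/// For example (an illustrative sketch): an operand of type tensor<1x5xf32>
/// accessed through affine_map<(d0, d1) -> (0, d1)> is rewritten to
/// tensor<5xf32> with the index map affine_map<(d0, d1) -> (d1)> and the
/// reassociation [affine_map<(d0, d1) -> (d0, d1)>], which collapses both
/// original dimensions into the single remaining one.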
static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
                                                    OpOperand *opOperand,
                                                    MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount=*/0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  // Compute the tensor or scalar replacement type.
  Type actualType = opOperand->get().getType();
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
           "unsupported strided memrefs");
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
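  // For example (illustrative), a result that was collapsed to tensor<5xf32>
  // is expanded back to an original tensor<1x5xf32> result type with a
  // linalg.tensor_expand_shape.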
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<linalg::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<linalg::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      UnitExtentReplacementInfo replacementInfo =
          replaceUnitExtents(genericOp, opOperand, context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and
    // contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of a subtensor (or source of
/// a subtensor_insert) operation with the given offsets and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociation is not empty, then fold the remaining unit
  // dimensions into the last dimension. If the reassociation so far is empty,
  // then leave it empty. This will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `subtensor` operations to rank-reduced versions.
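///
/// For example (an illustrative sketch), a subtensor with static sizes
/// [1, 4, 1] and result type tensor<1x4x1xf32> is rewritten to a rank-reduced
/// subtensor producing tensor<4xf32>, followed by a linalg.tensor_expand_shape
/// back to tensor<1x4x1xf32>.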
struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = subTensorOp.getType();
    SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType =
        SubTensorOp::inferRankReducedResultType(reassociation->size(),
                                                subTensorOp.getSourceType(),
                                                offsets, sizes, strides)
            .cast<RankedTensorType>();

    Location loc = subTensorOp.getLoc();
    Value newSubTensor = rewriter.create<SubTensorOp>(
        loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(
        subTensorOp, resultType, newSubTensor, *reassociation);
    return success();
  }
};

/// Convert `subtensor_insert` operations to rank-reduced versions.
struct UseRankReducedSubTensorInsertOp
    : public OpRewritePattern<SubTensorInsertOp> {
  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, UseRankReducedSubTensorOp,
               UseRankReducedSubTensorInsertOp>(context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
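/// When the `foldOneTripLoopsOnly` option is set, only the FoldUnitDimLoops
/// pattern is applied; otherwise the full set of patterns from
/// populateFoldUnitExtentDimsPatterns is used (see runOnFunction below).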
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}