//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove the use of unit-extent
// dimensions to specify broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to have
/// a single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original
/// `indexingMaps`.
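///
/// For example (an illustrative case, not from a specific test): with
/// `unitDims = {0}`, the map `(d0, d1) -> (d0, d1)` becomes `(d0) -> (0, d0)`
/// and `(d0, d1) -> (d1)` becomes `(d0) -> (d0)`: the unit iteration dimension
/// is replaced by the constant 0 and the remaining dimensions are renumbered.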
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  assert(genericOp->getNumRegions() == 1 &&
         genericOp->getRegion(0).getBlocks().size() == 1 &&
         "expected generic operation to have one block.");
  Block &block = genericOp->getRegion(0).front();

  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
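    // `invertedMap` maps each iteration dimension to (one of) the flattened
    // operand dimensions that index it, so looking up that position in the
    // flattened static shape below tells whether the loop has a unit trip
    // count.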
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results of a linalg generic
/// operation that have unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with affine maps to represent that we will always
  // leave them unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getAffineMaps().empty())
      return llvm::None;
  }

  int64_t dim = 0;
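  // As an illustration (a made-up case): an operand of shape <1x1x5x1x4>
  // accessed through the map (d0, d1) -> (0, 0, d0, 0, d1) folds to the shape
  // <5x4>, the index map (d0, d1) -> (d0, d1), and the reassociation
  // {[0, 1, 2, 3], [4]}.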
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type, indexing map, and create a set of mappings representing an
        // identity reassociation.
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount = */ 0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
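    // For example (illustrative): a tensor<1x5xf32> input accessed through
    // (d0, d1) -> (0, d1) is collapsed to tensor<5xf32> with
    // linalg.tensor_collapse_shape before being passed to the new generic op.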
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with given offsets and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and the offset is 0. Strictly speaking, offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociation computed so far is empty,
  // leave it empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
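/// As an illustration (a made-up case): an extract_slice that produces a
/// tensor<1x4x1x8xf32>, where every unit dimension has size 1 and offset 0,
/// is rewritten into a rank-reduced extract_slice of type tensor<4x8xf32>
/// followed by a linalg.tensor_expand_shape back to tensor<1x4x1x8xf32>.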
struct UseRankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
                               reassociation->size(), sliceOp.getSourceType(),
                               offsets, sizes, strides)
                               .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(sliceOp, resultType,
                                                     newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
struct UseRankReducedInsertSliceOp
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
               UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
      context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
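/// When the `foldOneTripLoopsOnly` option is set, only the FoldUnitDimLoops
/// pattern is applied; otherwise the full set of unit-extent folding patterns
/// is used.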
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}