//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
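///
/// For illustration (values chosen here for exposition, not taken from a
/// test): with `unitDims = {1}` and `indexingMaps` containing
///   affine_map<(d0, d1) -> (d0, d1)> and affine_map<(d0, d1) -> (d0)>
/// the returned attribute would hold
///   affine_map<(d0) -> (d0, 0)> and affine_map<(d0) -> (d0)>,
/// i.e. the unit dimension is replaced by the constant 0 and the remaining
/// dimensions are renumbered.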
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  assert(genericOp->getNumRegions() == 1 &&
         genericOp->getRegion(0).getBlocks().size() == 1 &&
         "expected generic operation to have one block.");
  Block &block = genericOp->getRegion(0).front();

  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
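    //
    // For example (an illustrative case, not taken from a test): an operand of
    // type memref<1x?xf32> accessed through affine_map<(d0, d1) -> (d0, d1)>
    // pins loop d0 to the single value 0, so d0 is a unit-trip count loop and
    // becomes a candidate for removal below.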
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in
    // the reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op, and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
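///
/// For example (an illustrative case, not taken from a test): an operand of
/// type tensor<1x5x1xf32> accessed through
/// affine_map<(d0, d1, d2) -> (0, d1, 0)> yields
///   - type: tensor<5xf32>
///   - index map: affine_map<(d0, d1, d2) -> (d1)>
///   - reassociation: [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
/// i.e. all three original dimensions collapse into the single remaining one.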
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with affine maps to represent that we will always
  // leave them unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getAffineMaps().empty())
      return llvm::None;
  }

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
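  //
  // For example (assuming a ranked tensor result; illustrative only): a value
  // of type tensor<5xf32> expanded back to an original result type
  // tensor<1x5x1xf32> becomes roughly
  //   %0 = linalg.tensor_expand_shape %v [[0, 1, 2]]
  //       : tensor<5xf32> into tensor<1x5x1xf32>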
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type, indexing map, and create a set of mappings representing an
        // identity matrix.
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount = */ 0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible
    // (i.e. not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with given offsets, and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and offset is 0. Strictly speaking the offset 0 is not required in general,
/// but non-zero offsets are not handled by the SPIR-V backend at this point
/// (and potentially cannot be handled).
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, then fold the remaining
  // unit-dimensions into the last dimension. If the reassociations so far are
  // empty, then leave them empty. This will fold everything to a rank-0
  // tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
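///
/// For example (a sketch, not taken from the tests):
///
/// ```mlir
///   %0 = tensor.extract_slice %arg0[0, %o, 0] [1, 4, 1] [1, 1, 1]
///       : tensor<1x?x1xf32> to tensor<1x4x1xf32>
/// ```
///
/// is rewritten to an extract_slice producing a rank-reduced tensor<4xf32>,
/// followed by a linalg.tensor_expand_shape back to tensor<1x4x1xf32>.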
struct UseRankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
                               reassociation->size(), sliceOp.getSourceType(),
                               offsets, sizes, strides)
                               .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(sliceOp, resultType,
                                                     newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
struct UseRankReducedInsertSliceOp
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
               UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
      context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
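///
/// As a usage sketch (assuming the standard registration of this pass and the
/// flag spelling defined in the pass tablegen, e.g.
/// `-linalg-fold-unit-extent-dims`, with `foldOneTripLoopsOnly` exposed as a
/// pass option):
///
///   mlir-opt -linalg-fold-unit-extent-dims input.mlir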
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}