//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to have
/// a single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
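///
/// For illustration, with `unitDims = {1}` and a single indexing map
/// `affine_map<(d0, d1, d2) -> (d0, d1, d2)>`, the returned attribute would
/// hold `affine_map<(d0, d1) -> (d0, 0, d1)>`: d1 is pinned to the constant 0
/// and the remaining dims are renumbered.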
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  assert(genericOp->getNumRegions() == 1 &&
         genericOp->getRegion(0).getBlocks().size() == 1 &&
         "expected generic operation to have one block.");
  Block &block = genericOp->getRegion(0).front();

  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
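    // For example (illustrative case), an operand of type `tensor<1x?xf32>`
    // accessed through `affine_map<(d0, d1) -> (d0, d1)>` forces the d0 loop
    // to a single iteration, making d0 a candidate for removal.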
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in
    // the reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced
/// with an operand/result with the unit-extent dimension removed. This is only
/// done if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op, and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
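///
/// As an illustrative example, for an operand of type `tensor<1x5xf32>`
/// accessed through `affine_map<(d0, d1) -> (0, d1)>`, the function would
/// return the type `tensor<5xf32>`, the index map
/// `affine_map<(d0, d1) -> (d1)>`, and the reassociation
/// `[affine_map<(d0, d1) -> (d0, d1)>]` that collapses both original
/// dimensions into one.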
static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
                                                    OpOperand *opOperand,
                                                    MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount=*/0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get().getType());
  Type replacementType = elementType == opOperand->get().getType()
                             ? elementType
                             : RankedTensorType::get(newShape, elementType);
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor operands/results that have unit extents.
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;

    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      UnitExtentReplacementInfo replacementInfo =
          replaceUnitExtents(genericOp, opOperand, context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
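    // (This can happen, for instance, when every operand drops all of its
    // dimensions: the remaining maps then no longer cover every iteration
    // dimension.)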
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else {
          res.push_back(rewriter.create<TensorCollapseShapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
        }
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      RankedTensorType origResultType = genericOp.getResult(result.index())
                                            .getType()
                                            .template cast<RankedTensorType>();
      if (origResultType != result.value().getType()) {
        resultReplacements.push_back(rewriter.create<TensorExpandShapeOp>(
            loc, origResultType, result.value(),
            convertAffineMapArrayToExprs(reassociationMaps[index])));
      } else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of a subtensor (or source of
/// a subtensor_insert) operation with given offsets and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
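///
/// A few illustrative cases: static sizes [1, 4, 1, 8] produce the
/// reassociation [[0, 1], [2, 3]] (each unit dim is grouped with the next
/// non-unit dim), sizes [4, 1] produce [[0, 1]] (trailing unit dims are folded
/// into the last group), and sizes [1, 1] produce an empty reassociation,
/// i.e. a rank-0 result.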
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, then fold the remaining
  // unit-dimensions into the last dimension. If the reassociation built so far
  // is empty, then leave it empty. This will fold everything to a rank-0
  // tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `subtensor` operations to rank-reduced versions.
struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = subTensorOp.getType();
    SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType =
        SubTensorOp::inferRankReducedResultType(reassociation->size(),
                                                subTensorOp.getSourceType(),
                                                offsets, sizes, strides)
            .cast<RankedTensorType>();

    Location loc = subTensorOp.getLoc();
    Value newSubTensor = rewriter.create<SubTensorOp>(
        loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(
        subTensorOp, resultType, newSubTensor, *reassociation);
    return success();
  }
};

/// Convert `subtensor_insert` operations to rank-reduced versions.
struct UseRankReducedSubTensorInsertOp
    : public OpRewritePattern<SubTensorInsertOp> {
  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
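///
/// A minimal usage sketch, mirroring the pass below (assuming `funcOp` is the
/// FuncOp the caller wants to transform):
///
///   RewritePatternSet patterns(funcOp.getContext());
///   populateFoldUnitExtentDimsPatterns(patterns);
///   (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));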
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
               UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
      context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}