//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
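/// As an illustrative sketch (not taken from a test): with `unitDims` = {1}
/// and original maps
///   affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>
/// the returned maps would be
///   affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (0, d1)>
/// i.e. the unit dimension d1 is replaced by the constant 0 and the remaining
/// dimensions are renumbered.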
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Modify the region of an indexed generic op to drop arguments corresponding
/// to loops that are unit trip count.
template <typename OpTy>
static LogicalResult
replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
                               PatternRewriter &rewriter) {
  return success();
}

template <>
LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
    PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Block *entryBlock = &op->getRegion(0).front();
  rewriter.setInsertionPointToStart(entryBlock);
  Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
  for (unsigned unitDimLoop : unitDims) {
    entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
  }
  SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
  entryBlock->eraseArguments(unitDimsToErase);
  return success();
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
template <typename GenericOpTy>
struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    // TODO: remove once index ops are supported.
    if (op.hasIndexSemantics())
      return failure();

    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t, 4> dims;
    for (ShapedType shapedType : op.getShapedOperandTypes())
      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(op, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = op.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in
    // the reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return op.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(op);
    op.indexing_mapsAttr(newIndexingMapAttr);
    op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  RankedTensorType type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results of a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced
/// with an operand/result with the unit-extent dimension removed. This is only
/// done if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op, and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - the modified index map that can be used to access the replaced
///   result/operand.
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
                                                    RankedTensorType type,
                                                    MLIRContext *context) {
  ArrayRef<int64_t> shape = type.getShape();
  ArrayRef<AffineExpr> exprs = indexMap.getResults();
  SmallVector<AffineExpr, 2> reassociations;
  SmallVector<Attribute, 4> reassociationMaps;
  SmallVector<AffineExpr, 4> newIndexExprs;
  SmallVector<int64_t, 4> newShape;

  int64_t origRank = type.getRank();
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  unsigned dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(
        AffineMap::get(origRank, /*numSymbols=*/0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  UnitExtentReplacementInfo info = {
      RankedTensorType::get(newShape, type.getElementType()),
      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                     newIndexExprs, context),
      ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

/// Pattern to replace tensor operands/results that have unit extents.
template <typename GenericOpTy>
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    // TODO: remove once index ops are supported.
    if (op.hasIndexSemantics())
      return failure();

    if (!op.hasTensorSemantics())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = op.getLoc();

    SmallVector<AffineMap, 4> newIndexingMaps;
    SmallVector<ArrayAttr, 4> reassociationMaps;
    SmallVector<ShapedType, 4> newInputOutputTypes;
    bool doCanonicalization = false;
    for (auto it :
         llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
      auto replacementInfo = replaceUnitExtents(
          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
          context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != std::get<1>(it);
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and
    // contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else
          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              reassociationMaps[flattenedIdx]));
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(op.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
    GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            op.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, add a reshape to recover the
    // original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      RankedTensorType origResultType = op.getResult(result.index())
                                            .getType()
                                            .template cast<RankedTensorType>();
      if (origResultType != result.value().getType())
        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
            loc, origResultType, result.value(), reassociationMaps[index]));
      else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(op, resultReplacements);
    return success();
  }
};

/// Pattern to fold a pair of reshape ops where the intermediate has unit-dims,
/// for example:
///
///   %0 = linalg.tensor_reshape %arg0
///          [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///        : tensor<2048xf32> into tensor<1x4x1x512xf32>
///   %1 = linalg.tensor_reshape %0
///          [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///           affine_map<(d0, d1, d2, d3) -> (d3)>]
///        : tensor<1x4x1x512xf32> into tensor<4x512xf32>
///
/// can be replaced with
///
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///        : tensor<2048xf32> into tensor<4x512xf32>
///
/// Similarly,
///
///   %0 = linalg.tensor_reshape %arg0
///          [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///           affine_map<(d0, d1, d2, d3) -> (d3)>]
///        : tensor<4x512xf32> into tensor<1x4x1x512xf32>
///   %1 = linalg.tensor_reshape %0
///          [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///        : tensor<1x4x1x512xf32> into tensor<2048xf32>
///
/// can be replaced with
///
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///        : tensor<4x512xf32> into tensor<2048xf32>
struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Check that the source operand is created from a reshape as well.
    TensorReshapeOp parentReshapeOp =
        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
    if (!parentReshapeOp)
      return failure();

    RankedTensorType srcType = reshapeOp.getSrcType(),
                     dstType = reshapeOp.getResultType(),
                     parentSrcType = parentReshapeOp.getSrcType();
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
        !parentSrcType.hasStaticShape() ||
        srcType.getRank() < dstType.getRank() ||
        parentSrcType.getRank() == dstType.getRank())
      return failure();

    // Check whether the resulting tensor_reshape is folding or expanding after
    // the reshapeOp and parentReshapeOp are combined. If the final
    // tensor_reshape is folding, the parentReshapeOp is introducing unit-dims,
    // and the reshapeOp does an actual reshape. If the final tensor_reshape op
    // is expanding, the reshapeOp is introducing unit-dims, and the
    // parentReshapeOp does an actual reshape.
    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
    ArrayRef<int64_t> expandedShape =
        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
    ArrayRef<int64_t> foldedShape =
        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();

    unsigned expandedDim = 0, foldedDim = 0;
    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
        foldedShape.size());
    while (expandedDim < expandedShape.size() &&
           foldedDim < foldedShape.size()) {
      int64_t dstSize = foldedShape[foldedDim];
      int64_t srcSize = expandedShape[expandedDim];
      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        srcSize *= expandedShape[expandedDim];
      }
      if (srcSize == dstSize) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        // If the next dim in foldedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (foldedDim == foldedShape.size() - 1 ||
            foldedShape[foldedDim + 1] != 1) {
          while (expandedDim < expandedShape.size() &&
                 expandedShape[expandedDim] == 1) {
            reassociationExprs[foldedDim].push_back(
                rewriter.getAffineDimExpr(expandedDim++));
          }
        }
      } else {
        return failure();
      }
      foldedDim++;
    }
    if (expandedDim != expandedShape.size())
      return failure();

    SmallVector<AffineMap, 4> reassociationMaps =
        llvm::to_vector<4>(llvm::map_range(
            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
              return AffineMap::get(expandedShape.size(), 0, exprs,
                                    rewriter.getContext());
            }));
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        reshapeOp, dstType, parentReshapeOp.src(),
        rewriter.getAffineMapArrayAttr(reassociationMaps));
    return success();
  }
};

/// Pattern to fold subtensors that are just taking a slice of a unit-dimension
/// tensor. For example,
///
///   %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
///       : tensor<1x?x1xf32> to tensor<1x?x1xf32>
///
/// can be replaced with
///
///   %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
///       : tensor<1x?x1xf32> into tensor<?xf32>
///   %2 = subtensor %1[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
///   %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
///       : tensor<?xf32> into tensor<1x?x1xf32>
///
/// The additional tensor_reshapes will hopefully get canonicalized away with
/// other reshapes that drop unit dimensions. Three conditions must hold to
/// fold a dimension:
/// - The offset must be 0.
/// - The size must be 1.
/// - The dimension of the source type must be 1.
struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
    auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
      auto attr = valueOrAttr.dyn_cast<Attribute>();
      return attr && attr.cast<IntegerAttr>().getInt() == val;
    };

    if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
          return !hasValue(valueOrAttr, 1);
        }))
      return failure();

    // Find the expanded unit dimensions.
    SmallVector<ReassociationIndices> reassociation;
    SmallVector<OpFoldResult> newOffsets, newSizes;
    ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
    ReassociationIndices curr;
    for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
      curr.push_back(dim);
      if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
          hasValue(mixedSizes[dim], 1)) {
        continue;
      }
      newOffsets.push_back(mixedOffsets[dim]);
      newSizes.push_back(mixedSizes[dim]);
      reassociation.emplace_back(ReassociationIndices{});
      std::swap(reassociation.back(), curr);
    }
    if (newOffsets.size() == mixedOffsets.size())
      return failure();
    reassociation.back().append(curr.begin(), curr.end());
    SmallVector<OpFoldResult> newStrides(newOffsets.size(),
                                         rewriter.getI64IntegerAttr(1));
    Location loc = subTensorOp->getLoc();
    auto srcReshape = rewriter.create<TensorReshapeOp>(
        loc, subTensorOp.source(), reassociation);
    auto newSubTensorOp = rewriter.create<SubTensorOp>(
        loc, srcReshape, newOffsets, newSizes, newStrides);
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
    return success();
  }
};

} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
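/// A minimal usage sketch, mirroring what the pass below does (`ctx` and
/// `funcOp` are assumed to be available in the caller):
///
///   RewritePatternSet patterns(ctx);
///   mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns);
///   (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));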
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
               FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
               ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldReshapeOpWithUnitExtent>(context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns
          .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
              context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}