//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
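/// As an illustrative sketch (the maps and `unitDims` value below are
/// hypothetical, not taken from a specific test): with `unitDims = {0}` and
/// the original indexing maps
///   affine_map<(d0, d1) -> (d0, d1)>
///   affine_map<(d0, d1) -> (d1)>
/// the returned attribute would hold
///   affine_map<(d0) -> (0, d0)>
///   affine_map<(d0) -> (d0)>
/// i.e. the unit dimension is replaced by the constant 0 in every map and the
/// remaining iteration dimensions are renumbered.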
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(
      llvm::to_vector<4>(llvm::map_range(
          newIndexingMaps,
          [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
      context);
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
// TODO: Generalize this to indexed-generic ops by modifying the region
// arguments as well.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t, 4> dims;
    for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
    DenseSet<unsigned> unitDims;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1 &&
            iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
                getParallelIteratorTypeName())
          unitDims.insert(expr.index());
    }
    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  RankedTensorType type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an AffineConstantExpr
/// of value 0. Given the `type` of a result/operand of a Linalg op, and its
/// `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
                                                    RankedTensorType type,
                                                    MLIRContext *context) {
  ArrayRef<int64_t> shape = type.getShape();
  ArrayRef<AffineExpr> exprs = indexMap.getResults();
  SmallVector<AffineExpr, 2> reassociations;
  SmallVector<Attribute, 4> reassociationMaps;
  SmallVector<AffineExpr, 4> newIndexExprs;
  SmallVector<int64_t, 4> newShape;

  int64_t origRank = type.getRank();
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  unsigned dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*numSymbols = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  UnitExtentReplacementInfo info = {
      RankedTensorType::get(newShape, type.getElementType()),
      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                     newIndexExprs, context),
      ArrayAttr::get(reassociationMaps, context)};
  return info;
}

namespace {

/// Pattern to replace tensor operands/results that have unit extents.
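///
/// A sketch of the intended rewrite (the shapes and maps here are
/// illustrative, not taken from a specific test): an operand of type
/// `tensor<1x5xf32>` accessed through `affine_map<(d0, d1) -> (0, d1)>` is
/// collapsed with a `linalg.tensor_reshape` to `tensor<5xf32>` accessed
/// through `affine_map<(d0, d1) -> (d1)>`; if a result type changes in the
/// same way, a reshape back to the original type is inserted after the new
/// generic op.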
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support init_tensors and reductions.
    if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap, 4> newIndexingMaps;
    SmallVector<ArrayAttr, 4> reassociationMaps;
    SmallVector<ShapedType, 4> newInputOutputTypes;
    bool doCanonicalization = false;
    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
                             genericOp.getInputOutputShapedTypes())) {
      auto replacementInfo = replaceUnitExtents(
          std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != std::get<1>(it);
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and
    // contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else
          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              reassociationMaps[flattenedIdx]));
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputBuffers =
        insertReshapes(genericOp.output_buffers());
    SmallVector<Value, 4> newInitTensors =
        insertReshapes(genericOp.init_tensors());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
        newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
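    // E.g. (illustrative only): if the original generic op returned
    // tensor<1x5xf32> and the replacement op returns tensor<5xf32>, a
    // linalg.tensor_reshape from tensor<5xf32> back to tensor<1x5xf32> is
    // created and used as the replacement value.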
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumOperands();
      RankedTensorType origResultType = genericOp.getResult(result.index())
                                            .getType()
                                            .cast<RankedTensorType>();
      if (origResultType != result.value().getType())
        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
            loc, origResultType, result.value(), reassociationMaps[index]));
      else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

namespace {
/// Pattern to fold a pair of reshape ops where the intermediate has unit-dims,
/// for example:
///
///   %0 = linalg.tensor_reshape %arg0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///     : tensor<2048xf32> into tensor<1x4x1x512xf32>
///   %1 = linalg.tensor_reshape %0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///      affine_map<(d0, d1, d2, d3) -> (d3)>]
///     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
///
/// can be replaced with
///
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///     : tensor<2048xf32> into tensor<4x512xf32>
///
/// Similarly,
///
///   %0 = linalg.tensor_reshape %arg0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///      affine_map<(d0, d1, d2, d3) -> (d3)>]
///     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
///   %1 = linalg.tensor_reshape %0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///     : tensor<1x4x1x512xf32> into tensor<2048xf32>
///
/// can be replaced with
///
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///     : tensor<4x512xf32> into tensor<2048xf32>
struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Check that the source operand is created from a reshape as well.
    TensorReshapeOp parentReshapeOp =
        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
    if (!parentReshapeOp)
      return failure();

    RankedTensorType srcType = reshapeOp.getSrcType(),
                     dstType = reshapeOp.getResultType(),
                     parentSrcType = parentReshapeOp.getSrcType();
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
        !parentSrcType.hasStaticShape() ||
        srcType.getRank() < dstType.getRank() ||
        parentSrcType.getRank() == dstType.getRank())
      return failure();

    // Check whether the result tensor_reshape is a folding or expanding
    // reshape once the reshapeOp and parentReshapeOp are combined.
    // If the final tensor_reshape is folding, the parentReshapeOp is
    // introducing unit-dims, and the reshapeOp does an actual reshape.
    // If the final tensor_reshape op is expanding, the reshapeOp is
    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
    ArrayRef<int64_t> expandedShape =
        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
    ArrayRef<int64_t> foldedShape =
        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();

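    // Illustration (first example from the comment above; values are
    // descriptive only): combining tensor<2048xf32> -> tensor<1x4x1x512xf32>
    // -> tensor<4x512xf32> gives expandedShape = [4, 512] (from dstType) and
    // foldedShape = [2048] (from parentSrcType); the loop below groups both
    // expanded dims into the single folded dim, yielding the reassociation
    // map affine_map<(d0, d1) -> (d0, d1)>.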
    unsigned expandedDim = 0, foldedDim = 0;
    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
        foldedShape.size());
    while (expandedDim < expandedShape.size() &&
           foldedDim < foldedShape.size()) {
      int64_t dstSize = foldedShape[foldedDim];
      int64_t srcSize = expandedShape[expandedDim];
      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        srcSize *= expandedShape[expandedDim];
      }
      if (srcSize == dstSize) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        // If this is the last folded dim, or the next dim in foldedShape is
        // not unit-extent, collapse any subsequent unit-extent dims in
        // expandedShape into the current folded dim.
        if (foldedDim == foldedShape.size() - 1 ||
            foldedShape[foldedDim + 1] != 1) {
          while (expandedDim < expandedShape.size() &&
                 expandedShape[expandedDim] == 1) {
            reassociationExprs[foldedDim].push_back(
                rewriter.getAffineDimExpr(expandedDim++));
          }
        }
      } else {
        return failure();
      }
      foldedDim++;
    }
    if (expandedDim != expandedShape.size())
      return failure();

    SmallVector<AffineMap, 4> reassociationMaps =
        llvm::to_vector<4>(llvm::map_range(
            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
              return AffineMap::get(expandedShape.size(), 0, exprs,
                                    rewriter.getContext());
            }));
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        reshapeOp, dstType, parentReshapeOp.src(),
        rewriter.getAffineMapArrayAttr(reassociationMaps));
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldReshapeOpWithUnitExtent>(context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    if (foldOneTripLoopsOnly)
      patterns.insert<FoldUnitDimLoops>(context);
    else
      populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
    applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}