//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//
#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for the given operand
/// of the `producer`, given the indexing map of the result of the producer in
/// the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isInputTensor(consumerOpOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Currently support only operations with a single result.
  if (producer.getNumOutputs() != 1)
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if ((consumer.getNumReductionLoops())) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
      Value operand = std::get<0>(pair);
      if (operand == consumerOpOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
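    // Illustrative example (not derived from the code below): if
    // `consumerToProducerLoopsMap` is (d0, d1) -> (d1, d0), then a producer
    // `linalg.index 0` is remapped to the fused-loop index of dimension 1.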
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
      !controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(producer->getNumOperands() +
                        consumer->getNumOperands());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator it =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  for (OpOperand *opOperand : producer.getInputOperands()) {
    fusedOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // The fused op has invalid indexing maps. Typically this means something
    // is off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return llvm::None;
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer, const ControlFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

namespace {
/// Pattern to fuse a generic op with the producer of one of its operands.
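///
/// A minimal usage sketch (illustrative; assumes an existing
/// `RewritePatternSet patterns` and a `ControlFusionFn controlFn`):
///
///   patterns.add<FuseElementwiseOps>(patterns.getContext(), controlFn);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));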
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// linearization of indexing maps.
//===---------------------------------------------------------------------===//

// TODO(ravishankarm): The indexing maps these patterns produce in the general
// case are detrimental to transformations. These patterns are on a deprecation
// path in favor of fusion by collapsing, which covers the only legitimate use
// case of this pattern: folding unit-extent dims.

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
///   %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
///   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
///          tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
///   %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in the
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  // The new affine map cannot drop unused dimensions, but some new symbols may
  // have been added. Create a map with at least as many dimensions/symbols as
  // the original affine map.
  int64_t maxDim = -1;
  int64_t maxSym = -1;
  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
  return AffineMap::get(numDims, numSyms, resultExprs, context);
}

// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
// producer). Fusing when the operand has higher rank would require mods and
// divs in the indexing maps of the fused op, which would make it
// non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
    tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer)
    return false;
  return useIndexMap.isPermutation();
}

// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
// consumer).
static bool
isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
                                         AffineMap useIndexMap,
                                         bool asProducer) {
  if (asProducer)
    return false;
  return useIndexMap.isPermutation();
}

/// Check if the reshape operation is only expansion into/collapsing of
/// unit-dimensions.
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t position : indices)
      if (expandedShape[position] == 1)
        numUnitDims++;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}

namespace {
/// Pattern to fold a tensor_expand_shape op with its consumer by using the
/// source of the reshape op as the operand in the consumer (instead of the
/// result of the reshape). The corresponding index map in the consumer needs
/// to be modified to linearize the folded dimension.
///
/// For example,
///
///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
///       : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
///  %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///       ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///       -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
///  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///  %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///       ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///       -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer=*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap =
          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

/// Pattern to fold a tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
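///
/// Illustrative sketch (mirroring the producer-side example above): given
///
///  %0 = linalg.generic { ... } ... -> tensor<?x?x4x?xf32>
///  %1 = tensor.collapse_shape %0 [[0], [1, 2], [3]]
///         : tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
///
/// a new generic op is created whose output indexing map is
/// `(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)`, and all uses of %1 are
/// replaced with its result.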
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer=*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `fuseWithReshapeByExpansion`, which implements the
/// following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `genericOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor_expand_shape.
///  The indexing_map of the fused tensor in the `genericOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the generic op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the operands to the linalg generic are now higher dimensional,
///  reshapes can be introduced to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
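  //
  // Illustrative example (hypothetical values, not taken from the code
  // below): for a fused operand with indexing map (d0, d1) -> (d0, d1),
  // reshape reassociation [[0, 1], [2]] and expanded shape [2, 3, 5], this
  // computes `reassociation` = [[0, 1], [2]], `expandedShapeMap` =
  // [[2, 3], [5]] and `expandedOpNumDims` = 3.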
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
  originalLoopExtent.assign(originalLoopRange->begin(),
                            originalLoopRange->end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adapting the accessed
/// loop indices. Specifically, accesses of indices in the original operation
/// need to be replaced with linearizations of indices in the expanded op. That
/// requires the shape of the expanded dimensions to be static (at least all
/// but the most significant). For now check that these are all statically
/// sized. Note that this could be extended to handle the dynamic case, but the
/// implementation below uses `affine.apply`, which seems to have issues when
/// the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded op to
/// get the value that can replace all uses of the results of the original op.
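///
/// Illustrative example: for an indexing map (d0, d1) -> (d1, d0) where d1
/// expands into three dimensions and d0 expands into two, the returned
/// reassociation is [[0, 1, 2], [3, 4]].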
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return llvm::None;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumInputs());
  for (OpOperand *opOperand : genericOp.getInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                               : collapsingReshapeOp.src());
      continue;
    }
    if (genericOp.isInputTensor(opOperand)) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(genericOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return llvm::None;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
    auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(genericOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return llvm::None;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    }
  }

  // The iterator types of the expanded op are all parallel.
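  // (isFusableWithReshapeByDimExpansion only admits ops whose loops are all
  // parallel, so no other iterator type can occur here.)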
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
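///
/// Illustrative sketch: given a producer generic op yielding tensor<?x?xf32>
/// followed by
///
///  %1 = tensor.expand_shape %0 [[0, 1], [2]]
///       : tensor<?x?xf32> into tensor<?x4x?xf32>
///
/// the producer is rewritten to iterate over the expanded 3-d space and to
/// produce tensor<?x4x?xf32> directly, so the expand_shape can be removed.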
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
    if (!producer || producer.getNumOutputs() != 1 ||
        !isFusableWithReshapeByDimExpansion(producer,
                                            producer.getOutputOperand(0)) ||
        !controlFoldingReshapes(producer->getResult(0),
                                reshapeOp->getOpOperand(0)))
      return failure();
    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
    if (!replacementValues)
      return failure();
    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. The projected
/// permutation semantics guarantee that all the elements of the returned
/// reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return indexingMap.getResults()[pos]
            .cast<AffineDimExpr>()
            .getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence in the map returns true as well.
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            indexingMap.getResult(expr.index() + dimInSequence.index())
                .cast<AffineDimExpr>()
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is an AffineDimExpr) is part of
    // the sequence but is not its start, return false: the entire sequence
    // cannot occur in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = linalg.init_tensor [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
    return {};

  if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
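  // These are used below to check that folded reduction dimensions stay
  // contiguous in the reduction loop order.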
  SmallVector<int64_t> reductionDims;
  for (auto iteratorType : llvm::enumerate(genericOp.iterator_types())) {
    if (isReductionIterator(iteratorType.value())) {
      reductionDims.push_back(iteratorType.index());
    }
  }

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.iterator_types().getValue();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are all parallel or all reduction.
    Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary for reductions that are associative and commutative, but
    // we use the stricter definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (auto startDim : llvm::enumerate(reductionDims)) {
        // Move the window in `reductionDims` to the start of the folded
        // iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (auto foldedDim : llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
          return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims);
        }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that is an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (auto foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (auto dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return the mapping from the collapsed loop domain to the original loop
  /// domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return the mapping from the original loop domain to the collapsed loop
  /// domain. The mapping is a pair: the first value is the dimension in the
  /// collapsed loop that the original loop is mapped to, and the second is the
  /// relative position within that folded list. For example, if the original
  /// loop domain is 5D and is folded as
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in the collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from the iteration domain index in the original op to the iteration
  /// domain index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
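/// For example, assuming original iterator types
/// `["parallel", "parallel", "reduction"]` and folded iteration dims
/// `[[0, 1], [2]]`, the collapsed iterator types would be
/// `["parallel", "reduction"]` (the iterator type of the first dim of each
/// folded group).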
static SmallVector<StringRef>
getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<StringRef> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have ensured that the iterator types of all
    // folded dimensions are the same.
    collapsedIteratorTypes.push_back(
        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    // If the dim is not the first of the collapsed dims, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The next n dims are guaranteed to be collapsed. So just use the
    // iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}

/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      // This is the start of a collapsed dimension of the iteration space that
      // is guaranteed to be preserved in the indexing map. The number of
      // folded dims is obtained from the collapsed op to original op mapping.
      unsigned numFoldedDims =
          collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
              .size();
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
      counter += numFoldedDims;
    }
  }
  return operandReassociation;
}

/// Get the new value to use for a given `OpOperand` in the collapsed operation.
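/// If none of the dimensions accessed by the operand are collapsed, the
/// operand is returned as is; otherwise a `tensor.collapse_shape` is inserted.
/// For example, assuming an operand of type `tensor<?x4x?xf32>` with indexing
/// map `(d0, d1, d2) -> (d0, d1, d2)` and folded iteration dims `[[0, 1]]`, a
/// `tensor.collapse_shape` with reassociation `[[0, 1], [2]]` producing a
/// `tensor<?x?xf32>` would be created.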
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, there is nothing to do for
  // this operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
      loc, operand, operandReassociation);
  return reshapeOp.getResult();
}

/// Modify the `linalg.index` operations in the original generic op to their
/// values in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block,
                                     const CollapsingInfo &collapsingInfo,
                                     ValueRange loopRange,
                                     PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list, resolve the original induction variable
  // values in terms of the folded dimension induction variable.
  //   i_{folded} = (i_0 * d1 + i1) * d2 + i2
  // can be inverted to
  //   i2 = i_{folded} % d2
  //   i1 = (i_{folded} / d2) % d1
  //   i0 = i_{folded} / (d1 * d2)
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto &foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      indexReplacementVals[dim] =
          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
      newIndexVal =
          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.dim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}

/// Implementation of fusion with reshape operation by collapsing dimensions.
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
  // Bail on trivial no-op cases.
  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                       foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        genericOp, "illegal to collapse specified dimensions");
  }

  // Get the iterator types for the collapsed op.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc = genericOp->getLoc();

  // Get the input operands.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop ranges of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert an expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}

namespace {

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing dimensions of the loop.
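/// The `controlFoldingReshapes` function lets the caller decide, per use of
/// the `tensor.expand_shape` result, whether the collapsing should be applied
/// (for example, only when the reshape result has a single use).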
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
        continue;
      }

      Optional<SmallVector<Value>> replacements =
          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
                                         opOperand, rewriter);
      if (!replacements) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, replacements.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to convert tensor.expand_shape -> linalg.generic
// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
//===---------------------------------------------------------------------===//

// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
// collapsing, which provides more general functionality. This pattern is very
// specific to a particular use case. The fusion by collapsing can provide the
// same control to clients using the control function there.

static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

namespace {
/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
/// within the dimensions we need to merge.
///
/// For example,
///
///  %0 = tensor.expand_shape %A [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///      ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
/// into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///      : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg ops on tensors.
    if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. This could be extended to
    // permutations if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor.expand_shape operands and save the dimensions that
    // are merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity maps as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the associated reverse map.
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg op and that we
    // don't need to create new reshape operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (const auto &en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip merged dimensions except for the last one of each group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that the types match the uses.
    SmallVector<Value> newResults;
    for (const auto &result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
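/// For example, assuming a splat constant input such as
///
/// ```mlir
///   %cst = arith.constant dense<2.0> : tensor<4xf32>
///   %0 = linalg.generic ... ins(%cst, %arg0 : tensor<4xf32>, tensor<4xf32>) ...
/// ```
///
/// the `%cst` operand is dropped from the fused op and all uses of its block
/// argument are remapped to a scalar `arith.constant 2.0 : f32` created next
/// to the fused op.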
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation, with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes-to-loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparse compiler.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

/// Fold linalg.fill into linalg.generic.
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.region().front();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      payload.getArgument(opOperand->getOperandNumber())
          .replaceAllUsesWith(fillOp.value());
    }
    return success(fillFound);
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<
      FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
      FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
      FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
      patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
// favor of test passes that check the functionality of each of the patterns
// added here individually.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    // Add folding with reshape by expansion patterns.
    ControlFusionFn defaultControlFn = [](const OpResult &producer,
                                          const OpOperand &consumer) {
      return producer.hasOneUse();
    };

    // Add elementwise op fusion patterns.
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);

    populateFoldReshapeOpsByExpansionPatterns(
        patterns,
        allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape);

    // Add the sparse tensor rewriting patterns.
    populateSparseTensorRewriting(patterns);

    // General canonicalization patterns.
    AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
        patterns);

    // Add constant folding patterns.
    populateConstantFoldLinalgOperations(patterns, defaultControlFn);

    // Use TopDownTraversal for compile time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}