//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//
#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for the producer
/// operand `producerOpOperand`, given the indexing map of the result of the
/// producer in the producer and in the consumer.
38 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 39 OpOperand *producerOpOperand, AffineMap producerResultIndexMap, 40 AffineMap fusedConsumerArgIndexMap) { 41 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 42 // from consumer loop -> consumer arg tensor index/producer result tensor 43 // index. The fused loop is same as the consumer loop. For each producer arg 44 // the indexing map to be computed is a map from consumer loop -> producer 45 // arg tensor index. 46 // producerResultIndexMap is a map from producer loop -> tensor index. 47 // Compute the inverse to get map from tensor index -> producer loop. 48 // The inverse is a map from producer result tensor index -> producer loop. 49 AffineMap invProducerResultIndexMap = 50 inversePermutation(producerResultIndexMap); 51 assert(invProducerResultIndexMap && 52 "expected producer result indexing map to be invertible"); 53 54 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner()); 55 // argMap is a map from producer loop -> producer arg tensor index. 56 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); 57 58 // Compose argMap with invProducerResultIndexMap to get a map from 59 // producer result tensor index -> producer arg tensor index. 60 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 61 62 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 63 // consumer loop/ fused loop -> producer arg tensor index. 64 return t1.compose(fusedConsumerArgIndexMap); 65 } 66 67 /// Conditions for elementwise fusion of generic operations. 68 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, 69 OpOperand *consumerOpOperand) { 70 // Producer and consumer must have tensor semantics. 71 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 72 return false; 73 74 // Verify that 75 // - the producer has all "parallel" iterator type. 
76 if (producer.getNumParallelLoops() != producer.getNumLoops()) 77 return false; 78 79 // Only allow fusing the producer of an input operand for now. 80 // TODO: allow fusing the producer of an output operand. 81 if (!consumer.isInputTensor(consumerOpOperand)) 82 return false; 83 84 // Get the consumer index map. The number of results of the consumer index 85 // map must match the number of loops of the producer. 86 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 87 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 88 return false; 89 90 // Currently support only operations with single result. 91 if (producer.getNumOutputs() != 1) 92 return false; 93 94 // Finally the index_map for the result must be invertible. For now just 95 // verify it is a permutation. 96 AffineMap producerResultIndexMap = 97 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 98 if (!producerResultIndexMap.isPermutation()) 99 return false; 100 101 // Ensure that the fusion does not remove size information required to 102 // get the loop bounds. For non-reduction generics, this is trivially the 103 // case due to the output operand. For reductions, we need to check that after 104 // the fusion, each loop dimension has at least one input that defines it. 
105 if ((consumer.getNumReductionLoops())) { 106 BitVector coveredDims(consumer.getNumLoops(), false); 107 108 auto addToCoveredDims = [&](AffineMap map) { 109 for (auto result : map.getResults()) 110 if (auto dimExpr = result.dyn_cast<AffineDimExpr>()) 111 coveredDims[dimExpr.getPosition()] = true; 112 }; 113 114 for (auto pair : 115 llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) { 116 Value operand = std::get<0>(pair); 117 if (operand == consumerOpOperand->get()) 118 continue; 119 AffineMap operandMap = std::get<1>(pair); 120 addToCoveredDims(operandMap); 121 } 122 123 for (OpOperand *operand : producer.getInputOperands()) { 124 AffineMap newIndexingMap = 125 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 126 operand, producerResultIndexMap, consumerIndexMap); 127 addToCoveredDims(newIndexingMap); 128 } 129 if (!coveredDims.all()) 130 return false; 131 } 132 133 return true; 134 } 135 136 /// Generate the region of the fused tensor operation. The region of the fused 137 /// op must be empty. 138 static void 139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 140 AffineMap consumerToProducerLoopsMap, 141 OpOperand *consumerOpOperand, 142 unsigned nloops) { 143 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 144 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 145 // Build the region of the fused op. 146 Block &producerBlock = producer->getRegion(0).front(); 147 Block &consumerBlock = consumer->getRegion(0).front(); 148 Block *fusedBlock = new Block(); 149 fusedOp.region().push_back(fusedBlock); 150 BlockAndValueMapping mapper; 151 OpBuilder::InsertionGuard guard(rewriter); 152 rewriter.setInsertionPointToStart(fusedBlock); 153 154 // 2. Add an index operation for every fused loop dimension and use the 155 // `consumerToProducerLoopsMap` to map the producer indices. 156 if (producer.hasIndexSemantics()) { 157 // Add an index operation for every fused loop dimension. 
158 unsigned numFusedOpLoops = 159 std::max(producer.getNumLoops(), consumer.getNumLoops()); 160 SmallVector<Value> fusedIndices; 161 fusedIndices.reserve(numFusedOpLoops); 162 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 163 std::back_inserter(fusedIndices), [&](uint64_t dim) { 164 return rewriter.create<IndexOp>(producer.getLoc(), dim); 165 }); 166 for (IndexOp indexOp : 167 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 168 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 169 producer.getLoc(), 170 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 171 mapper.map(indexOp.getResult(), newIndex); 172 } 173 } 174 // TODO: allow fusing the producer of an output operand. 175 assert(consumer.isInputTensor(consumerOpOperand) && 176 "expected producer of input operand"); 177 // 3. Consumer input operands up to consumerIdx (exclusive). 178 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 179 consumerOpOperand->getOperandNumber())) // input assumption. 180 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 181 182 // Replacing consumerIdx requires getting the cloned, yielded, value from 183 // the (cloned) producer block. This happens in step 9. 184 185 // 4. Splice in producer's input operands. 186 for (BlockArgument bbArg : 187 producerBlock.getArguments().take_front(producer.getNumInputs())) 188 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 189 190 // 4.b. Producer output operand/map that is fused needs to be mapped to the 191 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 
192 assert(producer->getNumResults() == 1 && "expected single result producer"); 193 if (producer.isInitTensor(producer.getOutputOperand(0))) { 194 BlockArgument bbArg = producerBlock.getArguments() 195 .drop_front(producer.getNumInputs()) 196 // TODO: bbArg index of 197 .front(); 198 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 199 } 200 // 5. Remaining consumer's input operands (drop past index `consumerIdx`). 201 for (BlockArgument bbArg : 202 consumerBlock.getArguments() 203 .take_front(consumer.getNumInputs()) 204 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 205 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 206 // 6. All of consumer's output operands. 207 for (BlockArgument bbArg : 208 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 209 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 210 // 7. All of producer's output operands except the one fused. 211 // TODO: allow fusion of multi-result producers. 212 assert(producer->getNumResults() == 1 && "expected single result producer"); 213 214 // 8. Clone all producer operations except for the yield and index operations 215 // to the fused operation. 216 for (auto &op : producerBlock.without_terminator()) { 217 if (!isa<IndexOp>(op)) 218 rewriter.clone(op, mapper); 219 } 220 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just 221 // forward the yield operand. 222 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 223 // TODO: allow fusion of multi-result producers. 224 assert(producer->getNumResults() == 1 && "expected single result producer"); 225 unsigned producerResultNumber = 0; 226 Value replacement = 227 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 228 // Sanity checks, if replacement is not already in the mapper then it must be 229 // produced outside. 
230 if (replacement == yieldOp.getOperand(producerResultNumber)) { 231 if (auto bb = replacement.dyn_cast<BlockArgument>()) 232 assert(bb.getOwner() != &producerBlock && 233 "yielded block argument must have been mapped"); 234 else 235 assert(!producer->isAncestor(replacement.getDefiningOp()) && 236 "yielded value must have been mapped"); 237 } 238 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 239 replacement); 240 // 10. Clone operations from the consumer to the fused op. 241 for (auto &op : consumerBlock.getOperations()) 242 rewriter.clone(op, mapper); 243 244 // Sanity checks. 245 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 246 "Ill-formed GenericOp region"); 247 } 248 249 static Optional<SmallVector<Value>> 250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 251 const ControlFusionFn &controlFn, 252 PatternRewriter &rewriter) { 253 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 254 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 255 !controlFn(producer->getResult(0), *consumerOpOperand)) 256 return llvm::None; 257 258 // TODO: allow fusing the producer of an output operand. 259 assert(consumer.isInputTensor(consumerOpOperand) && 260 "expected producer of input operand"); 261 262 // Compute the fused operands list and indexing maps. 263 SmallVector<Value> fusedOperands; 264 SmallVector<AffineMap> fusedIndexMaps; 265 fusedOperands.reserve(producer->getNumOperands() + 266 consumer->getNumOperands()); 267 fusedIndexMaps.reserve(producer->getNumOperands() + 268 consumer->getNumOperands()); 269 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 270 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 
271 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 272 SmallVector<OpOperand *>::iterator it = 273 llvm::find(consumerInputs, consumerOpOperand); 274 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 275 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 276 fusedOperands.push_back(opOperand->get()); 277 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 278 } 279 // 4. Splice in producer's input operands/maps. 280 assert(producer->getNumResults() == 1 && "expected single result producer"); 281 AffineMap producerResultIndexMap = 282 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 283 for (OpOperand *opOperand : producer.getInputOperands()) { 284 fusedOperands.push_back(opOperand->get()); 285 // Compute indexing maps for the producer args in the fused operation. 286 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 287 opOperand, producerResultIndexMap, 288 consumer.getTiedIndexingMap(consumerOpOperand)); 289 fusedIndexMaps.push_back(map); 290 } 291 // 4.b. Producer output operand/map that is fused needs to be passed if it is 292 // an "initTensor" (i.e. its value is actually read). 293 assert(producer->getNumResults() == 1 && "expected single result producer"); 294 if (producer.isInitTensor(producer.getOutputOperand(0))) { 295 fusedOperands.push_back(producer.getOutputOperand(0)->get()); 296 // Compute indexing maps for the producer args in the fused operation. 297 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 298 producer.getOutputOperand(0), producerResultIndexMap, 299 consumer.getTiedIndexingMap(consumerOpOperand)); 300 fusedIndexMaps.push_back(map); 301 } 302 // 5. Remaining consumer's input operands/maps (drop past index 303 // `consumerIdx`). 
304 for (OpOperand *opOperand : 305 llvm::make_range(std::next(it), consumerInputs.end())) { 306 fusedOperands.push_back(opOperand->get()); 307 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 308 } 309 // 6. All of consumer's output operands (skip operands: added by the builder). 310 for (OpOperand *opOperand : consumer.getOutputOperands()) 311 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 312 // 7. All of producer's output operands/maps except the one fused. 313 // TODO: allow fusion of multi-result producers. 314 assert(producer->getNumResults() == 1 && "expected single result producer"); 315 316 // Generate the fused op. 317 SmallVector<Value> consumerOutputs = consumer.getOutputOperands(); 318 auto fusedOp = rewriter.create<GenericOp>( 319 consumer.getLoc(), consumer->getResultTypes(), 320 /*inputs=*/fusedOperands, 321 // TODO: handle outputs. 322 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 323 consumer.iterator_types(), 324 /*doc=*/nullptr, 325 /*library_call=*/nullptr); 326 if (!fusedOp.getShapesToLoopsMap()) { 327 // Fused op has invalid indexing maps. Typically this means something is off 328 // in the input, but going ahead here would result in verification errors. 329 // So cleanup and abort. 330 rewriter.eraseOp(fusedOp); 331 return llvm::None; 332 } 333 334 // Construct an AffineMap from consumer loops to producer loops. 
335 // consumer loop -> tensor index 336 AffineMap consumerResultIndexMap = 337 consumer.getTiedIndexingMap(consumerOpOperand); 338 // tensor index -> producer loop 339 AffineMap invProducerResultIndexMap = 340 inversePermutation(producerResultIndexMap); 341 assert(invProducerResultIndexMap && 342 "expected producer result indexig map to be invertible"); 343 // consumer loop -> producer loop 344 AffineMap consumerToProducerLoopsMap = 345 invProducerResultIndexMap.compose(consumerResultIndexMap); 346 347 generateFusedElementwiseOpRegion(rewriter, fusedOp, 348 consumerToProducerLoopsMap, 349 consumerOpOperand, consumer.getNumLoops()); 350 return SmallVector<Value>(fusedOp->getResults()); 351 } 352 353 static Optional<SmallVector<Value>> 354 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, 355 GenericOp producer, const ControlFusionFn &controlFn) { 356 if (producer->getNumResults() != 1) 357 return llvm::None; 358 359 return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn, 360 rewriter); 361 } 362 363 namespace { 364 /// Patterns to fuse a generic op, with the producer of its operands. 365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> { 366 public: 367 FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun, 368 PatternBenefit benefit = 1) 369 : OpRewritePattern<GenericOp>(context, benefit), 370 controlFn(std::move(fun)) {} 371 372 LogicalResult matchAndRewrite(GenericOp genericOp, 373 PatternRewriter &rewriter) const override { 374 // Find the first operand that is defined by another generic op on tensors. 
375 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 376 auto producer = 377 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp()); 378 if (!producer || !producer.hasTensorSemantics()) 379 continue; 380 Optional<SmallVector<Value>> fusedOpResults = 381 fuseElementwiseOps(rewriter, opOperand, producer, controlFn); 382 if (fusedOpResults) { 383 rewriter.replaceOp(genericOp, *fusedOpResults); 384 return success(); 385 } 386 } 387 return failure(); 388 } 389 390 private: 391 ControlFusionFn controlFn; 392 }; 393 } // namespace 394 395 //===---------------------------------------------------------------------===// 396 // Methods and patterns that fuse reshape ops with elementwise operations by 397 // expanding the dimensionality of the elementwise operations. 398 //===---------------------------------------------------------------------===// 399 400 /// Conditions for folding a generic operation with a reshape op by expanding 401 /// the iteration space dimensionality for tensor operations. These are 402 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 403 /// following fusion pattern. 404 /// 405 /// Consider 406 /// 407 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>) 408 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 409 /// affine_map<(d0, d1, d2) -> (d1, d2)>, 410 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] 411 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] 412 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 413 /// 414 /// The reshape can be folded into the `genericOp` if its loop dimensionality 415 /// is increased to match the result (operand) of the tensor_expand_shape. 416 /// The indexing_map of the fused tensor in the `genericOp` and the 417 /// reassociation map helps compute the indexing maps of the modified op. 
418 /// For the above example, based on the reassociation map it 419 /// can be concluded that 420 /// 421 /// - The loop used to access the first dimension of the fused tensor is split 422 /// into two. 423 /// - The loop used to access the second dimension of the fused tensor is kept 424 /// as is. 425 /// - The loop used to access the third dimension of the fused tensor is split 426 /// into three. 427 /// 428 /// i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified 429 /// op, then 430 /// 431 /// d0 -> e0, e1 432 /// d1 -> e2, e3, e4 433 /// d2 -> e5 434 /// 435 /// substituting this, the generic op can be rewritten as 436 /// 437 /// %d = linalg.generic ins(%0, %1 : ) 438 /// indexing_maps = 439 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>, 440 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>, 441 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>] 442 /// 443 /// Since operands to the linalg generic are now 5D, reshapes can be introduced 444 /// to make it consistent 445 /// 446 /// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]] 447 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 448 /// %1 = tensor.expand_shape %b [[0, 1, 2], [3]] 449 /// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32> 450 /// 451 /// The added reshapes are again expanding patterns, so they will get fused 452 /// with its producers if possible. 453 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, 454 OpOperand *fusableOpOperand) { 455 // Is fusable only if: 456 // - All the indexing maps for operands and results are projected 457 // permutations. 458 // - The fused tensor is not a scalar. 459 // - All the loops are parallel loops. 
460 return genericOp.hasTensorSemantics() && 461 llvm::all_of(genericOp.indexing_maps().getValue(), 462 [](Attribute attr) { 463 return attr.cast<AffineMapAttr>() 464 .getValue() 465 .isProjectedPermutation(); 466 }) && 467 genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && 468 llvm::all_of(genericOp.iterator_types(), [](Attribute attr) { 469 return attr.cast<StringAttr>().getValue() == 470 getParallelIteratorTypeName(); 471 }); 472 } 473 474 namespace { 475 /// Information needed to expand a generic operation to fold the reshape with 476 /// it. 477 class ExpansionInfo { 478 public: 479 // Computes the mapping from original dimensions of the op to the dimensions 480 // of the expanded op given the `indexingMap` of the fused operand/result of 481 // the generic op, the `reassocationMaps` of the reshape op and the shape of 482 // the expanded op. 483 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand, 484 ArrayRef<AffineMap> reassociationMaps, 485 ArrayRef<int64_t> expandedShape, 486 ArrayRef<int64_t> collapsedShape, 487 PatternRewriter &rewriter); 488 unsigned getOrigOpNumDims() const { return reassociation.size(); } 489 unsigned getExpandedOpNumDims() const { return expandedOpNumDims; } 490 ReassociationIndicesRef getExpandedDims(unsigned i) const { 491 return reassociation[i]; 492 } 493 ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const { 494 return expandedShapeMap[i]; 495 } 496 ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; } 497 498 private: 499 /// Reassociation from the dimensions in the original operation to the 500 /// dimension of the expanded operation. 501 SmallVector<ReassociationIndices> reassociation; 502 /// Mapping from extent of loops in the original operation, to the extent of 503 /// loops in the expanded operation. 504 SmallVector<SmallVector<int64_t>> expandedShapeMap; 505 /// Extent of the loop in the original operation. 
506 SmallVector<int64_t> originalLoopExtent; 507 unsigned expandedOpNumDims; 508 }; 509 } // namespace 510 511 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp, 512 OpOperand *fusableOpOperand, 513 ArrayRef<AffineMap> reassociationMaps, 514 ArrayRef<int64_t> expandedShape, 515 ArrayRef<int64_t> collapsedShape, 516 PatternRewriter &rewriter) { 517 if (reassociationMaps.empty()) 518 return failure(); 519 AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand); 520 521 Optional<SmallVector<int64_t, 4>> originalLoopRange = 522 linalgOp.getStaticLoopRanges(); 523 if (!originalLoopRange) 524 return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range"); 525 originalLoopExtent.assign(originalLoopRange->begin(), 526 originalLoopRange->end()); 527 528 reassociation.clear(); 529 expandedShapeMap.clear(); 530 // Compute the number of dimension in the expanded op that correspond to each 531 // dimension of the original op. 532 SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1); 533 expandedShapeMap.resize(fusedIndexMap.getNumDims()); 534 for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { 535 unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition(); 536 AffineMap foldedDims = reassociationMaps[resultExpr.index()]; 537 numExpandedDims[pos] = foldedDims.getNumResults(); 538 ArrayRef<int64_t> shape = 539 expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]); 540 expandedShapeMap[pos].assign(shape.begin(), shape.end()); 541 } 542 // The remaining dimensions remain the same. 543 for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims())) 544 if (expandedShapeMap[i].empty()) 545 expandedShapeMap[i] = {originalLoopExtent[i]}; 546 547 // Compute reassociation map from the original op to the expanded op. 
548 unsigned sum = 0; 549 reassociation.reserve(fusedIndexMap.getNumDims()); 550 for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) { 551 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value()); 552 reassociation.emplace_back(seq.begin(), seq.end()); 553 sum += numFoldedDim.value(); 554 } 555 expandedOpNumDims = sum; 556 return success(); 557 } 558 559 /// Epanding the body of a linalg operation requires adaptations of the accessed 560 /// loop indices. Specifically, access of indices in the original operation need 561 /// to be replaced with linearizations of indices in the expanded op. That 562 /// requires the shape of the expanded dimensions to be static (at least all but 563 /// the most significant). For now check that these are all statically sized. 564 /// Note that this could be extended to handle dynamic case, but the 565 /// implementation below uses `affine.apply` which seems to have issues when the 566 /// shapes are not static. 567 static LogicalResult isGenericOpExpandable(GenericOp genericOp, 568 const ExpansionInfo &expansionInfo, 569 PatternRewriter &rewriter) { 570 if (!genericOp.hasIndexSemantics()) 571 return success(); 572 for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) { 573 ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i); 574 if (expandedShape.size() == 1) 575 continue; 576 for (int64_t shape : expandedShape.drop_front()) { 577 if (ShapedType::isDynamic(shape)) { 578 return rewriter.notifyMatchFailure( 579 genericOp, "cannot expand due to index semantics and dynamic dims"); 580 } 581 } 582 } 583 return success(); 584 } 585 586 /// Return the indexing map to use in the expanded op for a given the 587 /// `indexingMap` of the original operation. 
588 static AffineMap 589 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 590 const ExpansionInfo &expansionInfo) { 591 SmallVector<AffineExpr> newExprs; 592 for (AffineExpr expr : indexingMap.getResults()) { 593 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 594 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 595 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 596 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 597 })); 598 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 599 } 600 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 601 indexingMap.getNumSymbols(), newExprs, 602 builder.getContext()); 603 } 604 605 /// Return the type of the operand/result to use in the expanded op given the 606 /// type in the original op. 607 static RankedTensorType getExpandedType(RankedTensorType originalType, 608 AffineMap indexingMap, 609 const ExpansionInfo &expansionInfo) { 610 SmallVector<int64_t> expandedShape; 611 for (AffineExpr expr : indexingMap.getResults()) { 612 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 613 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 614 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 615 } 616 return RankedTensorType::get(expandedShape, originalType.getElementType()); 617 } 618 619 /// Returns the reassociation maps to use in the `tensor.expand_shape` 620 /// operation to convert the operands of the original operation to operands of 621 /// the expanded operation. The same method is used to compute the 622 /// `tensor.collapse_shape` used to collapse the result of the expanded 623 /// op to get the value that can replace all uses of the results of the original 624 /// op. 
625 static SmallVector<ReassociationIndices> 626 getReassociationForExpansion(AffineMap indexingMap, 627 const ExpansionInfo &expansionInfo) { 628 SmallVector<ReassociationIndices> reassociation; 629 unsigned numReshapeDims = 0; 630 for (AffineExpr expr : indexingMap.getResults()) { 631 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 632 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 633 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 634 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 635 reassociation.emplace_back(std::move(indices)); 636 numReshapeDims += numExpandedDims; 637 } 638 return reassociation; 639 } 640 641 /// Update the body of an expanded linalg operation having index semantics. The 642 /// indices of the original operation need to be recovered by linearizing the 643 /// indices of the correspoding dimensions of the expanded operation. For now it 644 /// is assumed that the shapes of the expanded operation needed for 645 /// linearization are static. 646 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, 647 Location loc, Region &fusedRegion, 648 const ExpansionInfo &expansionInfo) { 649 // Replace the original indices by the linearization of the expanded indices. 650 for (IndexOp indexOp : 651 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) { 652 ArrayRef<int64_t> expandedDims = 653 expansionInfo.getExpandedDims(indexOp.dim()); 654 assert(!expandedDims.empty() && "expected valid expansion info"); 655 656 // Skip index operations that are not affected by the expansion. 657 if (expandedDims.size() == 1 && 658 expandedDims.front() == (int64_t)indexOp.dim()) 659 continue; 660 661 // Linearize the expanded indices of the original index dimension. 
662 OpBuilder::InsertionGuard guard(rewriter); 663 rewriter.setInsertionPointAfter(indexOp); 664 ArrayRef<int64_t> expandedDimsShape = 665 expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front(); 666 SmallVector<Value> expandedIndices; 667 expandedIndices.reserve(expandedDims.size() - 1); 668 llvm::transform( 669 expandedDims.drop_front(), std::back_inserter(expandedIndices), 670 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); }); 671 Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front()); 672 for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) { 673 assert(!ShapedType::isDynamic(std::get<0>(it))); 674 AffineExpr idx, acc; 675 bindDims(rewriter.getContext(), idx, acc); 676 newIndex = rewriter.create<AffineApplyOp>( 677 indexOp.getLoc(), idx + acc * std::get<0>(it), 678 ValueRange{std::get<1>(it), newIndex}); 679 } 680 rewriter.replaceOp(indexOp, newIndex); 681 } 682 } 683 684 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op 685 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes 686 /// that those conditions have been satisfied. 687 static Optional<SmallVector<Value>> 688 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, 689 OpOperand *fusableOpOperand, 690 PatternRewriter &rewriter) { 691 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && 692 "preconditions for fuse operation failed"); 693 // Check if reshape is expanding or collapsing. 694 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp); 695 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp); 696 bool isExpanding = (expandingReshapeOp != nullptr); 697 RankedTensorType expandedType = isExpanding 698 ? expandingReshapeOp.getResultType() 699 : collapsingReshapeOp.getSrcType(); 700 RankedTensorType collapsedType = isExpanding 701 ? 
expandingReshapeOp.getSrcType() 702 : collapsingReshapeOp.getResultType(); 703 704 ExpansionInfo expansionInfo; 705 if (failed(expansionInfo.compute( 706 genericOp, fusableOpOperand, 707 isExpanding ? expandingReshapeOp.getReassociationMaps() 708 : collapsingReshapeOp.getReassociationMaps(), 709 expandedType.getShape(), collapsedType.getShape(), rewriter))) 710 return llvm::None; 711 712 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) 713 return llvm::None; 714 715 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>( 716 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) { 717 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); 718 })); 719 720 SmallVector<Value> expandedOpOperands; 721 expandedOpOperands.reserve(genericOp.getNumInputs()); 722 for (OpOperand *opOperand : genericOp.getInputOperands()) { 723 if (opOperand == fusableOpOperand) { 724 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() 725 : collapsingReshapeOp.src()); 726 continue; 727 } 728 if (genericOp.isInputTensor(opOperand)) { 729 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 730 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 731 RankedTensorType expandedOperandType = 732 getExpandedType(opOperandType, indexingMap, expansionInfo); 733 if (expandedOperandType != opOperand->get().getType()) { 734 // Reshape the operand to get the right type. 
735 SmallVector<ReassociationIndices> reassociation = 736 getReassociationForExpansion(indexingMap, expansionInfo); 737 if (failed(reshapeLikeShapesAreCompatible( 738 [&](const Twine &msg) { 739 return rewriter.notifyMatchFailure(genericOp, msg); 740 }, 741 opOperandType.getShape(), expandedOperandType.getShape(), 742 reassociation, 743 /*isExpandingReshape=*/true))) 744 return llvm::None; 745 expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>( 746 genericOp.getLoc(), expandedOperandType, opOperand->get(), 747 reassociation)); 748 continue; 749 } 750 } 751 expandedOpOperands.push_back(opOperand->get()); 752 } 753 754 Location loc = genericOp.getLoc(); 755 SmallVector<Value> outputs; 756 for (OpOperand *opOperand : genericOp.getOutputOperands()) { 757 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 758 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 759 RankedTensorType expandedOutputType = 760 getExpandedType(opOperandType, indexingMap, expansionInfo); 761 if (expandedOutputType != opOperand->get().getType()) { 762 SmallVector<ReassociationIndices> reassociation = 763 getReassociationForExpansion(indexingMap, expansionInfo); 764 if (failed(reshapeLikeShapesAreCompatible( 765 [&](const Twine &msg) { 766 return rewriter.notifyMatchFailure(genericOp, msg); 767 }, 768 opOperandType.getShape(), expandedOutputType.getShape(), 769 reassociation, 770 /*isExpandingReshape=*/true))) 771 return llvm::None; 772 outputs.push_back(rewriter.create<tensor::ExpandShapeOp>( 773 genericOp.getLoc(), expandedOutputType, opOperand->get(), 774 reassociation)); 775 } 776 } 777 778 // The iterator types of the expanded op are all parallel. 
779 SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(), 780 getParallelIteratorTypeName()); 781 782 TypeRange resultTypes = ValueRange(outputs).getTypes(); 783 auto fusedOp = 784 rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes, 785 /*inputs=*/expandedOpOperands, outputs, 786 expandedOpIndexingMaps, iteratorTypes); 787 Region &fusedRegion = fusedOp->getRegion(0); 788 Region &originalRegion = genericOp->getRegion(0); 789 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin()); 790 791 // Update the index accesses after the expansion. 792 updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo); 793 794 // Reshape the result values to their original shape if this is a collapsing 795 // reshape folded into its consumer. 796 SmallVector<Value> resultVals; 797 for (OpResult opResult : genericOp->getOpResults()) { 798 int64_t resultNumber = opResult.getResultNumber(); 799 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) { 800 SmallVector<ReassociationIndices> reassociation = 801 getReassociationForExpansion( 802 genericOp.getTiedIndexingMap( 803 genericOp.getOutputOperand(resultNumber)), 804 expansionInfo); 805 resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>( 806 genericOp.getLoc(), opResult.getType(), 807 fusedOp->getResult(resultNumber), reassociation)); 808 } else { 809 resultVals.push_back(fusedOp->getResult(resultNumber)); 810 } 811 } 812 // Assuming a single result. 813 return resultVals; 814 } 815 816 namespace { 817 818 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op, 819 /// when the reshape op is collapsing dimensions. The dimensionality of the loop 820 /// in the consumer is expanded. 
821 class FoldWithProducerReshapeOpByExpansion 822 : public OpRewritePattern<GenericOp> { 823 public: 824 FoldWithProducerReshapeOpByExpansion(MLIRContext *context, 825 ControlFusionFn foldReshapes, 826 PatternBenefit benefit = 1) 827 : OpRewritePattern<GenericOp>(context, benefit), 828 controlFoldingReshapes(std::move(foldReshapes)) {} 829 830 LogicalResult matchAndRewrite(GenericOp genericOp, 831 PatternRewriter &rewriter) const override { 832 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 833 tensor::CollapseShapeOp reshapeOp = 834 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>(); 835 if (!reshapeOp) 836 continue; 837 // Fold only if 838 // - The tensor reshape op is folding. 839 // - All constraints of fusing with reshape by expansion are met. 840 if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) || 841 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand))) 842 continue; 843 844 Optional<SmallVector<Value>> replacementValues = 845 fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); 846 if (!replacementValues) 847 return failure(); 848 rewriter.replaceOp(genericOp, replacementValues.getValue()); 849 return success(); 850 } 851 return failure(); 852 } 853 854 private: 855 ControlFusionFn controlFoldingReshapes; 856 }; 857 858 /// Pattern to fold a tensor_expand_shape op with its producer generic op 859 /// by expanding the dimensionality of the loop in the producer op. 
860 struct FoldReshapeWithGenericOpByExpansion 861 : public OpRewritePattern<tensor::ExpandShapeOp> { 862 863 FoldReshapeWithGenericOpByExpansion(MLIRContext *context, 864 ControlFusionFn foldReshapes, 865 PatternBenefit benefit = 1) 866 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), 867 controlFoldingReshapes(std::move(foldReshapes)) {} 868 869 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, 870 PatternRewriter &rewriter) const override { 871 // Fold only if all constraints of fusing with reshape by expansion are met. 872 GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>(); 873 if (!producer || producer.getNumOutputs() != 1 || 874 !isFusableWithReshapeByDimExpansion(producer, 875 producer.getOutputOperand(0)) || 876 !controlFoldingReshapes(producer->getResult(0), 877 reshapeOp->getOpOperand(0))) 878 return failure(); 879 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 880 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 881 if (!replacementValues) 882 return failure(); 883 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 884 return success(); 885 } 886 887 private: 888 ControlFusionFn controlFoldingReshapes; 889 }; 890 } // namespace 891 892 //===---------------------------------------------------------------------===// 893 // Methods and patterns to fuse reshape with linalg.generic operations by 894 // contraction of dimensions. 895 //===---------------------------------------------------------------------===// 896 897 /// For a given list of indices in the range of the `indexingMap` that are 898 /// folded, return the indices of the corresponding domain. Return `llvm::None` 899 /// on failure. Ensures that all the elements of the returned reassociation are 900 /// distinct. 
901 static ReassociationIndices 902 getDomainReassociation(AffineMap indexingMap, 903 ReassociationIndicesRef rangeReassociation) { 904 assert(indexingMap.isProjectedPermutation() && 905 "expected projected permutation"); 906 907 ReassociationIndices domainReassociation = llvm::to_vector<4>( 908 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t { 909 return indexingMap.getResults()[pos] 910 .cast<AffineDimExpr>() 911 .getPosition(); 912 })); 913 // The projected permutation semantics ensures that there is no repetition of 914 // the domain indices. 915 return domainReassociation; 916 } 917 918 /// For a given `dimSequence`, check if the sequence is conserved in the 919 /// `indexingMap`. `indexingMap` is expected to be a projected permutation. 920 /// Non-existence of the sequence returns true as well. 921 static bool isDimSequencePreserved(AffineMap indexingMap, 922 ReassociationIndicesRef dimSequence) { 923 assert(!dimSequence.empty() && 924 "expected non-empty list for dimension sequence"); 925 assert(indexingMap.isProjectedPermutation() && 926 "expected indexing map to be projected permutation"); 927 928 llvm::SmallDenseSet<unsigned, 4> sequenceElements; 929 sequenceElements.insert(dimSequence.begin(), dimSequence.end()); 930 931 unsigned dimSequenceStart = dimSequence[0]; 932 for (const auto &expr : enumerate(indexingMap.getResults())) { 933 unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition(); 934 // 1. Check if this start of the sequence. 935 if (dimInMapStart == dimSequenceStart) { 936 if (expr.index() + dimSequence.size() > indexingMap.getNumResults()) 937 return false; 938 // 1a. Check if sequence is preserved. 939 for (const auto &dimInSequence : enumerate(dimSequence)) { 940 unsigned dimInMap = 941 indexingMap.getResult(expr.index() + dimInSequence.index()) 942 .cast<AffineDimExpr>() 943 .getPosition(); 944 if (dimInMap != dimInSequence.value()) 945 return false; 946 } 947 // Found the sequence. 
Projected permutation 948 // enforces that all AffineDimExprs in the result are unique, so no 949 // further checks are needed. 950 return true; 951 } 952 // 2. If position in the expr (which is of type AffineDimExpr) is part 953 // of sequence, return false here. This implies the entire sequence does not 954 // exist in the indexing map. 955 if (sequenceElements.count(dimInMapStart)) 956 return false; 957 } 958 // 3. No element of sequence found. Return true. 959 return true; 960 } 961 962 // Return the list of dimensions of the iteration domain that can be 963 // collapsed to allow for fusion with the a producer that is an expand_shape 964 // operation. If all dimensions created by expansion can be collapsed in the 965 // iteration space then the reshape is defunct. 966 // 967 // Example: 968 // 969 // ```mlir 970 // #map = affine_map<(d0, d1) -> (d0, d1)> 971 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32> 972 // %2 = linalg.init_tensor [..] : tensor<?x4xf32> 973 // %3 = linalg.generic { 974 // indexing_maps = [#map, #map], 975 // iterator_types = ["parallel" ,"parallel"]} 976 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. } 977 // ``` 978 // 979 // can be fused by collapsing the dimensions of the iteration space. 980 // 981 // ```mlir 982 // #map = affine_map<(d0) -> (d0)> 983 // %2 = linalg.init_tensor [..] : tensor<?xf32> 984 // %3 = linalg.generic { 985 // indexing_maps = [#map, #map], 986 // iterator_types = ["parallel"]} 987 // ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. } 988 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32> 989 // ``` 990 // 991 // In the following example, 992 // 993 // ```mlir 994 // #map0 = affine_map<(d0, d1) -> (d0, d1)> 995 // #map1 = affine_map<(d0, d1) -> (d1, d0)> 996 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32> 997 // %2 = linalg.init_tensor [..] 
: tensor<4x?xf32> 998 // %2 = linalg.generic { 999 // indexing_maps = [#map0, #map1], 1000 // iterator_types = ["parallel" ,"parallel"]} 1001 // ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. } 1002 // ``` 1003 // 1004 // the reshape cannot be fused with the generic op by collapsing the op 1005 // dimensions since the indexing maps will have to contain mods and divs 1006 // to preserve the accesses pattern. When no dimensions of the iteration 1007 // space are collapsable and empty vector is returned. 1008 static SmallVector<ReassociationIndices> 1009 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, 1010 ArrayRef<ReassociationIndices> reassociation) { 1011 // Some basic checks for this fusion to be valid. 1012 if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1) 1013 return {}; 1014 1015 if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) { 1016 return map.isProjectedPermutation(); 1017 })) { 1018 return {}; 1019 } 1020 1021 // Compute all the loops with the reduction iterator types. 1022 SmallVector<int64_t> reductionDims; 1023 for (const auto &iteratorType : llvm::enumerate(genericOp.iterator_types())) { 1024 if (isReductionIterator(iteratorType.value())) { 1025 reductionDims.push_back(iteratorType.index()); 1026 } 1027 } 1028 1029 llvm::SmallDenseSet<unsigned, 4> processedIterationDims; 1030 AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand); 1031 auto iteratorTypes = genericOp.iterator_types().getValue(); 1032 SmallVector<ReassociationIndices> iterationSpaceReassociation; 1033 for (ReassociationIndicesRef foldedRangeDims : reassociation) { 1034 assert(!foldedRangeDims.empty() && "unexpected empty reassociation"); 1035 1036 // Ignore dims that are not folded. 
1037 if (foldedRangeDims.size() == 1) 1038 continue; 1039 1040 ReassociationIndices foldedIterationSpaceDims = 1041 getDomainReassociation(indexingMap, foldedRangeDims); 1042 1043 // Check that the folded iteration dims do not contain already processed 1044 // dims. 1045 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) { 1046 return processedIterationDims.count(dim); 1047 })) 1048 continue; 1049 1050 // Check that all folded iterator types are all parallel or all reductions. 1051 Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]]; 1052 if (!isParallelIterator(startIteratorType) && 1053 !isReductionIterator(startIteratorType)) 1054 continue; 1055 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) { 1056 return iteratorTypes[dim] != startIteratorType; 1057 })) 1058 continue; 1059 1060 // If the folded dimensions correspond to a "reduction" iterator type, 1061 // the folded dimensions need to be "in-order". Strictly speaking this is 1062 // not necessary, for reductions that are associative and commutative, but 1063 // using a more strict definition of reduction for now. 1064 if (isReductionIterator(startIteratorType)) { 1065 bool isContiguous = false; 1066 for (const auto &startDim : llvm::enumerate(reductionDims)) { 1067 // Move window in `reductionDims` to start of the folded iteration dims. 1068 if (startDim.value() != foldedIterationSpaceDims[0]) 1069 continue; 1070 // If sizes doesnt match, trivial not contiguous. This condition should 1071 // not be hit. 1072 if (startDim.index() + foldedIterationSpaceDims.size() > 1073 reductionDims.size()) 1074 break; 1075 // Check that the contiguity is maintained. 
1076 isContiguous = true; 1077 for (const auto &foldedDim : 1078 llvm::enumerate(foldedIterationSpaceDims)) { 1079 if (reductionDims[foldedDim.index() + startDim.index()] != 1080 foldedDim.value()) { 1081 isContiguous = false; 1082 break; 1083 } 1084 } 1085 break; 1086 } 1087 if (!isContiguous) 1088 continue; 1089 } 1090 1091 // Check that the sequence is preserved in all indexing maps. 1092 if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) { 1093 return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims); 1094 })) 1095 continue; 1096 1097 processedIterationDims.insert(foldedIterationSpaceDims.begin(), 1098 foldedIterationSpaceDims.end()); 1099 iterationSpaceReassociation.emplace_back( 1100 std::move(foldedIterationSpaceDims)); 1101 } 1102 1103 return iterationSpaceReassociation; 1104 } 1105 1106 /// Helper class to carry state while collapsing the `linalg.generic` op. 1107 namespace { 1108 class CollapsingInfo { 1109 public: 1110 LogicalResult initialize(unsigned origNumLoops, 1111 ArrayRef<ReassociationIndices> foldedIterationDims) { 1112 llvm::SmallDenseSet<int64_t, 4> processedDims; 1113 // Find all the dims that are folded. 1114 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) { 1115 if (foldedIterationDim.empty()) 1116 continue; 1117 // If the folded dims contain dims already folded, that's illegal 1118 // specification. Repetition within a list is also illegal. 1119 for (auto dim : foldedIterationDim) { 1120 if (dim >= origNumLoops) 1121 return failure(); 1122 if (processedDims.count(dim)) 1123 return failure(); 1124 processedDims.insert(dim); 1125 } 1126 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(), 1127 foldedIterationDim.end()); 1128 } 1129 if (processedDims.size() > origNumLoops) 1130 return failure(); 1131 1132 // Add all the preserved dims of the original op as single 1133 // elements to `collapsedOpToOrigOpIterationDim`. 
1134 for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) { 1135 if (processedDims.count(dim)) 1136 continue; 1137 collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim}); 1138 } 1139 1140 llvm::sort(collapsedOpToOrigOpIterationDim, 1141 [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) { 1142 return lhs[0] < rhs[0]; 1143 }); 1144 origOpToCollapsedOpIterationDim.resize(origNumLoops); 1145 for (const auto &foldedDims : 1146 llvm::enumerate(collapsedOpToOrigOpIterationDim)) { 1147 for (const auto &dim : enumerate(foldedDims.value())) 1148 origOpToCollapsedOpIterationDim[dim.value()] = 1149 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index()); 1150 } 1151 return success(); 1152 } 1153 1154 /// Return mapping from collapsed loop domain to original loop domain. 1155 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const { 1156 return collapsedOpToOrigOpIterationDim; 1157 } 1158 1159 /// Return mapping from original loop domain to collapsed loop domain. The 1160 /// mapping is a pair. First value is the dimension in the collapsed loop that 1161 /// the original loop is mapped to. Second is the relative position in folded 1162 /// list of this domain. For example if the original loop domain is 3D, and 1163 /// the collapsed loop domain is folding all of it, i.e. 1164 /// 1165 /// ``` 1166 /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]` 1167 /// ``` 1168 /// 1169 /// then 1170 /// 1171 /// ``` 1172 /// origOpToCollapsedOpMapping[0] = {0, 0}; 1173 /// origOpToCollapsedOpMapping[1] = {0, 1}; 1174 /// origOpToCollapsedOpMapping[2] = {0, 2}; 1175 /// origOpToCollapsedOpMapping[3] = {1, 0}; 1176 /// origOpToCollapsedOpMapping[4] = {1, 1}; 1177 /// ``` 1178 /// 1179 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const { 1180 return origOpToCollapsedOpIterationDim; 1181 } 1182 1183 /// Return the collapsed op iteration domain rank. 
1184 unsigned getCollapsedOpIterationRank() const { 1185 return collapsedOpToOrigOpIterationDim.size(); 1186 } 1187 1188 private: 1189 /// Map from the iteration domain index in collapsed op to the iteration 1190 /// domain indices in the original op. 1191 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim; 1192 1193 /// Map from iteration domain index in the original op to the iteration domain 1194 /// index in the collapsed op. 1195 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim; 1196 }; 1197 } // namespace 1198 1199 /// Get the iterator types for the collapsed operation given the original 1200 /// iterator types and collapsed dimensions. 1201 static SmallVector<StringRef> 1202 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes, 1203 const CollapsingInfo &collapsingInfo) { 1204 SmallVector<StringRef> collapsedIteratorTypes; 1205 for (ReassociationIndicesRef foldedIterDims : 1206 collapsingInfo.getCollapsedOpToOrigOpMapping()) { 1207 assert(!foldedIterDims.empty() && 1208 "reassociation indices expected to have non-empty sets"); 1209 // Just pick the iterator type of the first folded dim. Pre-condition checks 1210 // expected to have checked that iterator types of all folded dimensions are 1211 // the same. 1212 collapsedIteratorTypes.push_back( 1213 iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue()); 1214 } 1215 return collapsedIteratorTypes; 1216 } 1217 1218 /// Compute the indexing map in the collapsed op that corresponds to the given 1219 /// `indexingMap` of the original operation. 
1220 static AffineMap 1221 getCollapsedOpIndexingMap(AffineMap indexingMap, 1222 const CollapsingInfo &collapsingInfo) { 1223 MLIRContext *context = indexingMap.getContext(); 1224 assert(indexingMap.isProjectedPermutation() && 1225 "expected indexing map to be projected permutation"); 1226 SmallVector<AffineExpr> resultExprs; 1227 auto origOpToCollapsedOpMapping = 1228 collapsingInfo.getOrigOpToCollapsedOpMapping(); 1229 for (auto expr : indexingMap.getResults()) { 1230 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 1231 // If the dim is not the first of the collapsed dim, do nothing. 1232 if (origOpToCollapsedOpMapping[dim].second != 0) 1233 continue; 1234 // The next n-dims are guaranteed to be collapsed. So just use the 1235 // iteration dimension of the collapsed op. 1236 resultExprs.push_back( 1237 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context)); 1238 } 1239 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0, 1240 resultExprs, context); 1241 } 1242 1243 /// Return the `reassociation` indices to use to collapse the operand when the 1244 /// iteration space of a generic op is collapsed. 1245 static SmallVector<ReassociationIndices> 1246 getOperandReassociation(AffineMap indexingMap, 1247 const CollapsingInfo &collapsingInfo) { 1248 unsigned counter = 0; 1249 SmallVector<ReassociationIndices> operandReassociation; 1250 auto origOpToCollapsedOpMapping = 1251 collapsingInfo.getOrigOpToCollapsedOpMapping(); 1252 auto collapsedOpToOrigOpMapping = 1253 collapsingInfo.getCollapsedOpToOrigOpMapping(); 1254 while (counter < indexingMap.getNumResults()) { 1255 unsigned dim = 1256 indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition(); 1257 if (origOpToCollapsedOpMapping[dim].second == 0) { 1258 // This is the start of a collapsed dimensions of the iteration that 1259 // is gauranteed to be preserved in the indexing map. The number of folded 1260 // dims is obtained from the collapsed op to original op mapping. 
1261 unsigned numFoldedDims = 1262 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first] 1263 .size(); 1264 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims); 1265 operandReassociation.emplace_back(range.begin(), range.end()); 1266 counter += numFoldedDims; 1267 } 1268 } 1269 return operandReassociation; 1270 } 1271 1272 /// Get the new value to use for a given `OpOperand` in the collapsed operation. 1273 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp, 1274 OpOperand *opOperand, 1275 const CollapsingInfo &collapsingInfo, 1276 OpBuilder &builder) { 1277 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 1278 SmallVector<ReassociationIndices> operandReassociation = 1279 getOperandReassociation(indexingMap, collapsingInfo); 1280 1281 // If the number of entries in the reassocation for the operand is same as the 1282 // number of results of the indexing map, then nothing to do for this operand. 1283 Value operand = opOperand->get(); 1284 if (operandReassociation.size() == indexingMap.getNumResults()) 1285 return operand; 1286 1287 // Insert a reshape to collapse the dimensions. 1288 auto reshapeOp = builder.create<tensor::CollapseShapeOp>( 1289 loc, operand, operandReassociation); 1290 return reshapeOp.getResult(); 1291 } 1292 1293 /// Modify the `linalg.index` operations in the original generic op, to its 1294 /// value in the collapsed operation. 1295 void generateCollapsedIndexingRegion(Location loc, Block *block, 1296 const CollapsingInfo &collapsingInfo, 1297 ValueRange loopRange, 1298 PatternRewriter &rewriter) { 1299 OpBuilder::InsertionGuard g(rewriter); 1300 rewriter.setInsertionPointToStart(block); 1301 1302 // Collect all the original index ops. 1303 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>()); 1304 1305 // For each folded dimension list resolve the original induction variable 1306 // values in terms of the folded dimension induction variable. 
1307 // i_{folded} = (i_0 * d1 + i1) * d2 + i2. 1308 // can be inverted to 1309 // i2 = i_{folded} % d2 1310 // i1 = (i_{folded} / d2) % d1 1311 // i0 = i_{folded} / (d1 * d2) 1312 llvm::DenseMap<unsigned, Value> indexReplacementVals; 1313 for (auto &foldedDims : 1314 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) { 1315 ReassociationIndicesRef foldedDimsRef(foldedDims.value()); 1316 Value newIndexVal = 1317 rewriter.create<linalg::IndexOp>(loc, foldedDims.index()); 1318 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) { 1319 indexReplacementVals[dim] = 1320 rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]); 1321 newIndexVal = 1322 rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]); 1323 } 1324 indexReplacementVals[foldedDims.value().front()] = newIndexVal; 1325 } 1326 1327 for (auto indexOp : indexOps) { 1328 auto dim = indexOp.dim(); 1329 rewriter.replaceOp(indexOp, indexReplacementVals[dim]); 1330 } 1331 } 1332 1333 /// Implementation of fusion with reshape operation by collapsing dimensions. 1334 static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims( 1335 GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, 1336 OpOperand *fusableOpOperand, PatternRewriter &rewriter) { 1337 // Bail on trivial no-op cases. 1338 if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() || 1339 llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { 1340 return foldedDims.size() <= 1; 1341 })) 1342 return failure(); 1343 1344 CollapsingInfo collapsingInfo; 1345 if (failed(collapsingInfo.initialize(genericOp.getNumLoops(), 1346 foldedIterationDims))) { 1347 return rewriter.notifyMatchFailure( 1348 genericOp, "illegal to collapse specified dimensions"); 1349 } 1350 1351 // Get the iterator types for the operand. 
1352 SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes( 1353 genericOp.iterator_types().getValue(), collapsingInfo); 1354 1355 // Get the indexing maps. 1356 auto indexingMaps = llvm::to_vector( 1357 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) { 1358 return getCollapsedOpIndexingMap(map, collapsingInfo); 1359 })); 1360 1361 Location loc = genericOp->getLoc(); 1362 1363 // Get the input operands. 1364 auto inputOperands = llvm::to_vector( 1365 llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) { 1366 return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, 1367 rewriter); 1368 })); 1369 1370 // Get the output operands and result types. 1371 SmallVector<Type> resultTypes; 1372 SmallVector<Value> outputOperands; 1373 resultTypes.reserve(genericOp.getNumOutputs()); 1374 outputOperands.reserve(genericOp.getNumOutputs()); 1375 for (OpOperand *output : genericOp.getOutputOperands()) { 1376 Value newOutput = 1377 getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter); 1378 outputOperands.push_back(newOutput); 1379 resultTypes.push_back(newOutput.getType()); 1380 } 1381 1382 // Create the generic op. 1383 auto collapsedGenericOp = rewriter.create<linalg::GenericOp>( 1384 loc, resultTypes, inputOperands, outputOperands, indexingMaps, 1385 iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); 1386 Block *origOpBlock = &genericOp->getRegion(0).front(); 1387 Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front(); 1388 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, 1389 collapsedOpBlock->getArguments()); 1390 1391 if (collapsedGenericOp.hasIndexSemantics()) { 1392 // Collect the loop range of the generic op. 
1393 OpBuilder::InsertionGuard g(rewriter); 1394 rewriter.setInsertionPoint(collapsedGenericOp); 1395 SmallVector<Range> loopRanges = 1396 cast<LinalgOp>(genericOp.getOperation()) 1397 .createLoopRanges(rewriter, genericOp.getLoc()); 1398 assert(llvm::all_of(loopRanges, 1399 [](Range range) { 1400 return matchPattern(range.offset, m_Zero()) && 1401 matchPattern(range.stride, m_One()); 1402 }) && 1403 "expected all loop ranges to have zero start and unit stride"); 1404 SmallVector<Value> loopBound = llvm::to_vector( 1405 llvm::map_range(loopRanges, [](Range range) { return range.size; })); 1406 generateCollapsedIndexingRegion(loc, 1407 &collapsedGenericOp->getRegion(0).front(), 1408 collapsingInfo, loopBound, rewriter); 1409 } 1410 1411 // Insert expanding reshape for the result to get back the original result 1412 // type. 1413 SmallVector<Value> results; 1414 for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) { 1415 Value collapsedOpResult = 1416 collapsedGenericOp->getResult(originalResult.index()); 1417 auto originalResultType = 1418 originalResult.value().getType().cast<ShapedType>(); 1419 auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>(); 1420 if (collapsedOpResultType.getRank() != originalResultType.getRank()) { 1421 AffineMap indexingMap = 1422 genericOp.getTiedIndexingMapForResult(originalResult.value()); 1423 SmallVector<ReassociationIndices> reassociation = 1424 getOperandReassociation(indexingMap, collapsingInfo); 1425 Value result = rewriter.create<tensor::ExpandShapeOp>( 1426 loc, originalResultType, collapsedOpResult, reassociation); 1427 results.push_back(result); 1428 } else { 1429 results.push_back(collapsedOpResult); 1430 } 1431 } 1432 return results; 1433 } 1434 1435 namespace { 1436 1437 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by 1438 /// contracting dimensions of the loop. 
1439 class FoldWithProducerReshapeOpByCollapsing 1440 : public OpRewritePattern<GenericOp> { 1441 public: 1442 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context, 1443 ControlFusionFn foldReshapes, 1444 PatternBenefit benefit = 1) 1445 : OpRewritePattern<GenericOp>(context, benefit), 1446 controlFoldingReshapes(std::move(foldReshapes)) {} 1447 1448 LogicalResult matchAndRewrite(GenericOp genericOp, 1449 PatternRewriter &rewriter) const override { 1450 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 1451 tensor::ExpandShapeOp reshapeOp = 1452 opOperand->get().getDefiningOp<tensor::ExpandShapeOp>(); 1453 if (!reshapeOp) 1454 continue; 1455 1456 SmallVector<ReassociationIndices> collapsableIterationDims = 1457 getCollapsableIterationSpaceDims(genericOp, opOperand, 1458 reshapeOp.getReassociationIndices()); 1459 if (collapsableIterationDims.empty() || 1460 !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) { 1461 continue; 1462 } 1463 1464 Optional<SmallVector<Value>> replacements = 1465 collapseGenericOpIterationDims(genericOp, collapsableIterationDims, 1466 opOperand, rewriter); 1467 if (!replacements) { 1468 return rewriter.notifyMatchFailure( 1469 genericOp, "failed to do the fusion by collapsing transformation"); 1470 } 1471 1472 rewriter.replaceOp(genericOp, replacements.getValue()); 1473 return success(); 1474 } 1475 return failure(); 1476 } 1477 1478 private: 1479 ControlFusionFn controlFoldingReshapes; 1480 }; 1481 } // namespace 1482 1483 //===---------------------------------------------------------------------===// 1484 // Methods and patterns that fuse constants with linalg.generic operations. 1485 //===---------------------------------------------------------------------===// 1486 1487 namespace { 1488 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1489 /// handle cases where the constant is not single-valued. 
1490 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> { 1491 public: 1492 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1) 1493 : OpRewritePattern<GenericOp>(context, benefit) {} 1494 1495 LogicalResult matchAndRewrite(GenericOp genericOp, 1496 PatternRewriter &rewriter) const override { 1497 if (!genericOp.hasTensorSemantics()) 1498 return failure(); 1499 for (OpOperand *opOperand : genericOp.getInputOperands()) { 1500 Operation *def = opOperand->get().getDefiningOp(); 1501 Attribute constantAttr; 1502 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool { 1503 { 1504 DenseElementsAttr splatAttr; 1505 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) && 1506 splatAttr.isSplat() && 1507 splatAttr.getType().getElementType().isIntOrFloat()) { 1508 constantAttr = splatAttr.getSplatValue<Attribute>(); 1509 return true; 1510 } 1511 } 1512 { 1513 IntegerAttr intAttr; 1514 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) { 1515 constantAttr = intAttr; 1516 return true; 1517 } 1518 } 1519 { 1520 FloatAttr floatAttr; 1521 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) { 1522 constantAttr = floatAttr; 1523 return true; 1524 } 1525 } 1526 return false; 1527 }; 1528 1529 auto resultValue = opOperand->get().dyn_cast<OpResult>(); 1530 if (!def || !resultValue || !isScalarOrSplatConstantOp(def)) 1531 continue; 1532 1533 // The operands and the indexing_maps of the fused operation the same as 1534 // the operands and indexing_maps of the generic operations with the 1535 // values at the constant index dropped. 
1536 SmallVector<AffineMap> fusedIndexMaps; 1537 SmallVector<Value> fusedOperands; 1538 SmallVector<Location> fusedLocs{genericOp.getLoc()}; 1539 fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); 1540 fusedOperands.reserve(genericOp.getNumInputs()); 1541 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); 1542 for (OpOperand *inputOperand : genericOp.getInputOperands()) { 1543 if (inputOperand == opOperand) 1544 continue; 1545 Value inputValue = inputOperand->get(); 1546 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand)); 1547 fusedOperands.push_back(inputValue); 1548 fusedLocs.push_back(inputValue.getLoc()); 1549 } 1550 for (OpOperand *outputOperand : genericOp.getOutputOperands()) 1551 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand)); 1552 1553 // Check if the operation shapes to loops map is computable. 1554 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 1555 return rewriter.notifyMatchFailure( 1556 genericOp, "fused op loop bound computation failed"); 1557 } 1558 1559 // Create a constant scalar value from the splat constant. 1560 Value scalarConstant = rewriter.create<arith::ConstantOp>( 1561 def->getLoc(), constantAttr, constantAttr.getType()); 1562 1563 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 1564 auto fusedOp = rewriter.create<GenericOp>( 1565 rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), 1566 /*inputs=*/fusedOperands, 1567 /*outputs=*/outputOperands, 1568 rewriter.getAffineMapArrayAttr(fusedIndexMaps), 1569 genericOp.iterator_types(), 1570 /*doc=*/nullptr, 1571 /*library_call=*/nullptr); 1572 1573 // Map the block argument corresponding to the replaced argument with the 1574 // scalar constant. 
1575 Region ®ion = genericOp->getRegion(0); 1576 Block &entryBlock = *region.begin(); 1577 BlockAndValueMapping mapping; 1578 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()), 1579 scalarConstant); 1580 Region &fusedRegion = fusedOp->getRegion(0); 1581 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(), 1582 mapping); 1583 rewriter.replaceOp(genericOp, fusedOp->getResults()); 1584 return success(); 1585 } 1586 return failure(); 1587 } 1588 }; 1589 1590 } // namespace 1591 1592 //===---------------------------------------------------------------------===// 1593 // Miscellaneous patterns that help fusion. 1594 //===---------------------------------------------------------------------===// 1595 1596 namespace { 1597 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if 1598 /// the value of the `outs` operand is not used within the op. This is only 1599 /// implemented for `linalg.generic` operations for now, but should hold for all 1600 /// linalg structured ops. 1601 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> { 1602 using OpRewritePattern<GenericOp>::OpRewritePattern; 1603 1604 LogicalResult matchAndRewrite(GenericOp op, 1605 PatternRewriter &rewriter) const override { 1606 rewriter.startRootUpdate(op); 1607 bool modifiedOutput = false; 1608 Location loc = op.getLoc(); 1609 for (OpOperand *opOperand : op.getOutputOperands()) { 1610 if (!op.payloadUsesValueFromOperand(opOperand)) { 1611 Value operandVal = opOperand->get(); 1612 auto operandType = operandVal.getType().dyn_cast<RankedTensorType>(); 1613 if (!operandType) 1614 continue; 1615 1616 // If outs is sparse, leave it to the sparse compiler. 1617 if (sparse_tensor::getSparseTensorEncoding(operandVal.getType())) 1618 continue; 1619 1620 // If outs is already an `init_tensor` operation, nothing to do. 
1621 auto definingOp = operandVal.getDefiningOp<InitTensorOp>(); 1622 if (definingOp) 1623 continue; 1624 modifiedOutput = true; 1625 SmallVector<Value> dynamicDims; 1626 for (const auto &dim : llvm::enumerate(operandType.getShape())) { 1627 if (dim.value() != ShapedType::kDynamicSize) 1628 continue; 1629 dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>( 1630 loc, operandVal, dim.index())); 1631 } 1632 Value initTensor = rewriter.create<InitTensorOp>( 1633 loc, dynamicDims, operandType.getShape(), 1634 operandType.getElementType()); 1635 op->setOperand(opOperand->getOperandNumber(), initTensor); 1636 } 1637 } 1638 if (!modifiedOutput) { 1639 rewriter.cancelRootUpdate(op); 1640 return failure(); 1641 } 1642 rewriter.finalizeRootUpdate(op); 1643 return success(); 1644 } 1645 }; 1646 1647 /// Fold linalg.fill into linalg.generic 1648 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> { 1649 using OpRewritePattern<GenericOp>::OpRewritePattern; 1650 1651 LogicalResult matchAndRewrite(GenericOp genericOp, 1652 PatternRewriter &rewriter) const override { 1653 if (!genericOp.hasTensorSemantics()) 1654 return failure(); 1655 bool fillFound = false; 1656 Block &payload = genericOp.region().front(); 1657 for (OpOperand *opOperand : genericOp.getInputOperands()) { 1658 if (!genericOp.payloadUsesValueFromOperand(opOperand)) 1659 continue; 1660 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>(); 1661 if (!fillOp) 1662 continue; 1663 fillFound = true; 1664 payload.getArgument(opOperand->getOperandNumber()) 1665 .replaceAllUsesWith(fillOp.value()); 1666 } 1667 return success(fillFound); 1668 } 1669 }; 1670 } // namespace 1671 1672 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( 1673 RewritePatternSet &patterns, 1674 const ControlFusionFn &controlFoldingReshapes) { 1675 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(), 1676 controlFoldingReshapes); 1677 
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), 1678 controlFoldingReshapes); 1679 } 1680 1681 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( 1682 RewritePatternSet &patterns, 1683 const ControlFusionFn &controlFoldingReshapes) { 1684 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(), 1685 controlFoldingReshapes); 1686 } 1687 1688 void mlir::linalg::populateElementwiseOpsFusionPatterns( 1689 RewritePatternSet &patterns, 1690 const ControlFusionFn &controlElementwiseOpsFusion) { 1691 auto *context = patterns.getContext(); 1692 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion); 1693 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant, 1694 RemoveOutsDependency>(context); 1695 } 1696 1697 //===---------------------------------------------------------------------===// 1698 // Passes 1699 //===---------------------------------------------------------------------===// 1700 1701 namespace { 1702 1703 /// Pass that fuses generic ops on tensors. Used only for testing. 1704 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the 1705 // patterns added here heavily depends on the cost function used. Having an 1706 // opinionated pass of this form is not recommended. Deprecate this pass in 1707 // favor of test passes that check the functionality of each of the patterns 1708 // added here individually. 1709 struct LinalgElementwiseOpFusionPass 1710 : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> { 1711 void runOnOperation() override { 1712 Operation *op = getOperation(); 1713 MLIRContext *context = op->getContext(); 1714 RewritePatternSet patterns(context); 1715 1716 // Add folding with reshape by expansion patterns. 1717 ControlFusionFn defaultControlFn = [](const OpResult &producer, 1718 const OpOperand &consumer) { 1719 return producer.hasOneUse(); 1720 }; 1721 1722 // Add elementwise op fusion patterns. 
1723 populateElementwiseOpsFusionPatterns(patterns, defaultControlFn); 1724 1725 populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn); 1726 1727 // Add the sparse tensor rewriting patterns. 1728 populateSparseTensorRewriting(patterns); 1729 1730 // General canonicalization patterns. 1731 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 1732 GenericOp::getCanonicalizationPatterns(patterns, context); 1733 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); 1734 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); 1735 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns( 1736 patterns); 1737 1738 // Add constant folding patterns. 1739 populateConstantFoldLinalgOperations(patterns, defaultControlFn); 1740 1741 // Use TopDownTraversal for compile time reasons 1742 GreedyRewriteConfig grc; 1743 grc.useTopDownTraversal = true; 1744 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), 1745 grc); 1746 } 1747 }; 1748 1749 } // namespace 1750 1751 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() { 1752 return std::make_unique<LinalgElementwiseOpFusionPass>(); 1753 } 1754