//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements fusion of elementwise linalg operations on tensors.
//
//===----------------------------------------------------------------------===//
#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for the given operand
/// of the `producer`, computed from the indexing map of the producer result
/// (`producerResultIndexMap`) and the indexing map of that result in the
/// consumer (`fusedConsumerArgIndexMap`).
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop / fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
71 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 72 return false; 73 74 // Verify that 75 // - the producer has all "parallel" iterator type. 76 if (producer.getNumParallelLoops() != producer.getNumLoops()) 77 return false; 78 79 // Only allow fusing the producer of an input operand for now. 80 // TODO: allow fusing the producer of an output operand. 81 if (!consumer.isInputTensor(consumerOpOperand)) 82 return false; 83 84 // Get the consumer index map. The number of results of the consumer index 85 // map must match the number of loops of the producer. 86 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 87 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 88 return false; 89 90 // Currently support only operations with single result. 91 if (producer.getNumOutputs() != 1) 92 return false; 93 94 // Finally the index_map for the result must be invertible. For now just 95 // verify it is a permutation. 96 AffineMap producerResultIndexMap = 97 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 98 if (!producerResultIndexMap.isPermutation()) 99 return false; 100 101 // Ensure that the fusion does not remove size information required to 102 // get the loop bounds. For non-reduction generics, this is trivially the 103 // case due to the output operand. For reductions, we need to check that after 104 // the fusion, each loop dimension has at least one input that defines it. 105 if ((consumer.getNumReductionLoops())) { 106 BitVector coveredDims(consumer.getNumLoops(), false); 107 108 auto addToCoveredDims = [&](AffineMap map) { 109 for (auto result : map.getResults()) 110 if (auto dimExpr = result.dyn_cast<AffineDimExpr>()) 111 coveredDims[dimExpr.getPosition()] = true; 112 }; 113 114 for (auto pair : 115 llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) { 116 Value operand = std::get<0>(pair); 117 if (operand == consumerOpOperand->get()) 118 continue; 119 AffineMap operandMap = std::get<1>(pair); 120 addToCoveredDims(operandMap); 121 } 122 123 for (OpOperand *operand : producer.getInputOperands()) { 124 AffineMap newIndexingMap = 125 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 126 operand, producerResultIndexMap, consumerIndexMap); 127 addToCoveredDims(newIndexingMap); 128 } 129 if (!coveredDims.all()) 130 return false; 131 } 132 133 return true; 134 } 135 136 /// Generate the region of the fused tensor operation. The region of the fused 137 /// op must be empty. 138 static void 139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 140 AffineMap consumerToProducerLoopsMap, 141 OpOperand *consumerOpOperand, 142 unsigned nloops) { 143 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 144 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 145 // Build the region of the fused op. 146 Block &producerBlock = producer->getRegion(0).front(); 147 Block &consumerBlock = consumer->getRegion(0).front(); 148 Block *fusedBlock = new Block(); 149 fusedOp.region().push_back(fusedBlock); 150 BlockAndValueMapping mapper; 151 OpBuilder::InsertionGuard guard(rewriter); 152 rewriter.setInsertionPointToStart(fusedBlock); 153 154 // 2. Add an index operation for every fused loop dimension and use the 155 // `consumerToProducerLoopsMap` to map the producer indices. 156 if (producer.hasIndexSemantics()) { 157 // Add an index operation for every fused loop dimension. 
158 unsigned numFusedOpLoops = 159 std::max(producer.getNumLoops(), consumer.getNumLoops()); 160 SmallVector<Value> fusedIndices; 161 fusedIndices.reserve(numFusedOpLoops); 162 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 163 std::back_inserter(fusedIndices), [&](uint64_t dim) { 164 return rewriter.create<IndexOp>(producer.getLoc(), dim); 165 }); 166 for (IndexOp indexOp : 167 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 168 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 169 producer.getLoc(), 170 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 171 mapper.map(indexOp.getResult(), newIndex); 172 } 173 } 174 // TODO: allow fusing the producer of an output operand. 175 assert(consumer.isInputTensor(consumerOpOperand) && 176 "expected producer of input operand"); 177 // 3. Consumer input operands up to consumerIdx (exclusive). 178 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 179 consumerOpOperand->getOperandNumber())) // input assumption. 180 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 181 182 // Replacing consumerIdx requires getting the cloned, yielded, value from 183 // the (cloned) producer block. This happens in step 9. 184 185 // 4. Splice in producer's input operands. 186 for (BlockArgument bbArg : 187 producerBlock.getArguments().take_front(producer.getNumInputs())) 188 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 189 190 // 4.b. Producer output operand/map that is fused needs to be mapped to the 191 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 192 assert(producer->getNumResults() == 1 && "expected single result producer"); 193 if (producer.isInitTensor(producer.getOutputOperand(0))) { 194 BlockArgument bbArg = producerBlock.getArguments() 195 .drop_front(producer.getNumInputs()) 196 // TODO: bbArg index of 197 .front(); 198 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 199 } 200 // 5. Remaining consumer's input operands (drop past index `consumerIdx`). 201 for (BlockArgument bbArg : 202 consumerBlock.getArguments() 203 .take_front(consumer.getNumInputs()) 204 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 205 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 206 // 6. All of consumer's output operands. 207 for (BlockArgument bbArg : 208 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 209 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 210 // 7. All of producer's output operands except the one fused. 211 // TODO: allow fusion of multi-result producers. 212 assert(producer->getNumResults() == 1 && "expected single result producer"); 213 214 // 8. Clone all producer operations except for the yield and index operations 215 // to the fused operation. 216 for (auto &op : producerBlock.without_terminator()) { 217 if (!isa<IndexOp>(op)) 218 rewriter.clone(op, mapper); 219 } 220 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just 221 // forward the yield operand. 222 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 223 // TODO: allow fusion of multi-result producers. 
224 assert(producer->getNumResults() == 1 && "expected single result producer"); 225 unsigned producerResultNumber = 0; 226 Value replacement = 227 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 228 // Sanity checks, if replacement is not already in the mapper then it must be 229 // produced outside. 230 if (replacement == yieldOp.getOperand(producerResultNumber)) { 231 if (auto bb = replacement.dyn_cast<BlockArgument>()) 232 assert(bb.getOwner() != &producerBlock && 233 "yielded block argument must have been mapped"); 234 else 235 assert(!producer->isAncestor(replacement.getDefiningOp()) && 236 "yielded value must have been mapped"); 237 } 238 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 239 replacement); 240 // 10. Clone operations from the consumer to the fused op. 241 for (auto &op : consumerBlock.getOperations()) 242 rewriter.clone(op, mapper); 243 244 // Sanity checks. 245 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 246 "Ill-formed GenericOp region"); 247 } 248 249 static Optional<SmallVector<Value>> 250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 251 const ControlFusionFn &controlFn, 252 PatternRewriter &rewriter) { 253 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 254 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 255 !controlFn(producer->getResult(0), *consumerOpOperand)) 256 return llvm::None; 257 258 // TODO: allow fusing the producer of an output operand. 259 assert(consumer.isInputTensor(consumerOpOperand) && 260 "expected producer of input operand"); 261 262 // Compute the fused operands list and indexing maps. 263 SmallVector<Value> fusedOperands; 264 SmallVector<AffineMap> fusedIndexMaps; 265 fusedOperands.reserve(producer->getNumOperands() + 266 consumer->getNumOperands()); 267 fusedIndexMaps.reserve(producer->getNumOperands() + 268 consumer->getNumOperands()); 269 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 270 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 271 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 272 SmallVector<OpOperand *>::iterator it = 273 llvm::find(consumerInputs, consumerOpOperand); 274 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 275 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 276 fusedOperands.push_back(opOperand->get()); 277 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 278 } 279 // 4. Splice in producer's input operands/maps. 280 assert(producer->getNumResults() == 1 && "expected single result producer"); 281 AffineMap producerResultIndexMap = 282 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 283 for (OpOperand *opOperand : producer.getInputOperands()) { 284 fusedOperands.push_back(opOperand->get()); 285 // Compute indexing maps for the producer args in the fused operation. 286 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 287 opOperand, producerResultIndexMap, 288 consumer.getTiedIndexingMap(consumerOpOperand)); 289 fusedIndexMaps.push_back(map); 290 } 291 // 4.b. Producer output operand/map that is fused needs to be passed if it is 292 // an "initTensor" (i.e. its value is actually read). 
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return llvm::None;
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer, const ControlFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
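///
/// As a purely illustrative sketch (the shapes, maps, and value names below
/// are hypothetical and not taken from this file), fusing
///
///   #map = affine_map<(d0) -> (d0)>
///   %0 = linalg.generic {indexing_maps = [#map, #map],
///                        iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
///     ^bb0(%a: f32, %b: f32):
///       %e = arith.addf %a, %a : f32
///       linalg.yield %e : f32
///   } -> tensor<?xf32>
///   %1 = linalg.generic {indexing_maps = [#map, #map],
///                        iterator_types = ["parallel"]}
///       ins(%0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
///     ^bb0(%a: f32, %b: f32):
///       %m = arith.mulf %a, %a : f32
///       linalg.yield %m : f32
///   } -> tensor<?xf32>
///
/// produces a single linalg.generic whose region computes the addf followed
/// by the mulf, removing the intermediate tensor %0.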
365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> { 366 public: 367 FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun, 368 PatternBenefit benefit = 1) 369 : OpRewritePattern<GenericOp>(context, benefit), 370 controlFn(std::move(fun)) {} 371 372 LogicalResult matchAndRewrite(GenericOp genericOp, 373 PatternRewriter &rewriter) const override { 374 // Find the first operand that is defined by another generic op on tensors. 375 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 376 auto producer = 377 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp()); 378 if (!producer || !producer.hasTensorSemantics()) 379 continue; 380 Optional<SmallVector<Value>> fusedOpResults = 381 fuseElementwiseOps(rewriter, opOperand, producer, controlFn); 382 if (fusedOpResults) { 383 rewriter.replaceOp(genericOp, *fusedOpResults); 384 return success(); 385 } 386 } 387 return failure(); 388 } 389 390 private: 391 ControlFusionFn controlFn; 392 }; 393 } // namespace 394 395 //===---------------------------------------------------------------------===// 396 // Methods and patterns that fuse reshape ops with elementwise operations by 397 // expanding the dimensionality of the elementwise operations. 398 //===---------------------------------------------------------------------===// 399 400 /// Conditions for folding a generic operation with a reshape op by expanding 401 /// the iteration space dimensionality for tensor operations. These are 402 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 403 /// following fusion pattern. 404 /// 405 /// Consider 406 /// 407 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>) 408 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 409 /// affine_map<(d0, d1, d2) -> (d1, d2)>, 410 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] 411 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] 412 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 413 /// 414 /// The reshape can be folded into the `genericOp` if its loop dimensionality 415 /// is increased to match the result (operand) of the tensor.expand_shape. 416 /// The indexing_map of the fused tensor in the `genericOp` and the 417 /// reassociation map helps compute the indexing maps of the modified op. 418 /// For the above example, based on the reassociation map it 419 /// can be concluded that 420 /// 421 /// - The loop used to access the first dimension of the fused tensor is split 422 /// into two. 423 /// - The loop used to access the second dimension of the fused tensor is kept 424 /// as is. 425 /// - The loop used to access the third dimension of the fused tensor is split 426 /// into three. 427 /// 428 /// i.e. 
(e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the modified
/// op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
///   %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the iteration space of the modified op is now 6D, reshapes can be
/// introduced on the operands to make them consistent with the new indexing
/// maps:
///
///   %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///   %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(),
                            originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses to indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
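///
/// For example (an illustrative sketch; the dimension counts are assumed, not
/// taken from a specific caller), if the expansion info maps d0 -> (e0, e1)
/// and d1 -> (e2), then the original indexing map
///
///   affine_map<(d0, d1) -> (d1, d0)>
///
/// becomes
///
///   affine_map<(e0, e1, e2) -> (e2, e0, e1)>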
584 static AffineMap 585 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 586 const ExpansionInfo &expansionInfo) { 587 SmallVector<AffineExpr> newExprs; 588 for (AffineExpr expr : indexingMap.getResults()) { 589 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 590 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 591 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 592 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 593 })); 594 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 595 } 596 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 597 indexingMap.getNumSymbols(), newExprs, 598 builder.getContext()); 599 } 600 601 /// Return the type of the operand/result to use in the expanded op given the 602 /// type in the original op. 603 static RankedTensorType getExpandedType(RankedTensorType originalType, 604 AffineMap indexingMap, 605 const ExpansionInfo &expansionInfo) { 606 SmallVector<int64_t> expandedShape; 607 for (AffineExpr expr : indexingMap.getResults()) { 608 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 609 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 610 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 611 } 612 return RankedTensorType::get(expandedShape, originalType.getElementType()); 613 } 614 615 /// Returns the reassociation maps to use in the `tensor.expand_shape` 616 /// operation to convert the operands of the original operation to operands of 617 /// the expanded operation. The same method is used to compute the 618 /// `tensor.collapse_shape` used to collapse the result of the expanded 619 /// op to get the value that can replace all uses of the results of the original 620 /// op. 621 static SmallVector<ReassociationIndices> 622 getReassociationForExpansion(AffineMap indexingMap, 623 const ExpansionInfo &expansionInfo) { 624 SmallVector<ReassociationIndices> reassociation; 625 unsigned numReshapeDims = 0; 626 for (AffineExpr expr : indexingMap.getResults()) { 627 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 628 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 629 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 630 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 631 reassociation.emplace_back(std::move(indices)); 632 numReshapeDims += numExpandedDims; 633 } 634 return reassociation; 635 } 636 637 /// Update the body of an expanded linalg operation having index semantics. The 638 /// indices of the original operation need to be recovered by linearizing the 639 /// indices of the correspoding dimensions of the expanded operation. For now it 640 /// is assumed that the shapes of the expanded operation needed for 641 /// linearization are static. 642 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, 643 Location loc, Region &fusedRegion, 644 const ExpansionInfo &expansionInfo) { 645 // Replace the original indices by the linearization of the expanded indices. 646 for (IndexOp indexOp : 647 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) { 648 ArrayRef<int64_t> expandedDims = 649 expansionInfo.getExpandedDims(indexOp.dim()); 650 assert(!expandedDims.empty() && "expected valid expansion info"); 651 652 // Skip index operations that are not affected by the expansion. 653 if (expandedDims.size() == 1 && 654 expandedDims.front() == (int64_t)indexOp.dim()) 655 continue; 656 657 // Linearize the expanded indices of the original index dimension. 
658 OpBuilder::InsertionGuard guard(rewriter); 659 rewriter.setInsertionPointAfter(indexOp); 660 ArrayRef<int64_t> expandedDimsShape = 661 expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front(); 662 SmallVector<Value> expandedIndices; 663 expandedIndices.reserve(expandedDims.size() - 1); 664 llvm::transform( 665 expandedDims.drop_front(), std::back_inserter(expandedIndices), 666 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); }); 667 Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front()); 668 for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) { 669 assert(!ShapedType::isDynamic(std::get<0>(it))); 670 AffineExpr idx, acc; 671 bindDims(rewriter.getContext(), idx, acc); 672 newIndex = rewriter.create<AffineApplyOp>( 673 indexOp.getLoc(), idx + acc * std::get<0>(it), 674 ValueRange{std::get<1>(it), newIndex}); 675 } 676 rewriter.replaceOp(indexOp, newIndex); 677 } 678 } 679 680 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op 681 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes 682 /// that those conditions have been satisfied. 683 static Optional<SmallVector<Value>> 684 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, 685 OpOperand *fusableOpOperand, 686 PatternRewriter &rewriter) { 687 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && 688 "preconditions for fuse operation failed"); 689 // Check if reshape is expanding or collapsing. 690 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp); 691 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp); 692 bool isExpanding = (expandingReshapeOp != nullptr); 693 RankedTensorType expandedType = isExpanding 694 ? expandingReshapeOp.getResultType() 695 : collapsingReshapeOp.getSrcType(); 696 RankedTensorType collapsedType = isExpanding 697 ? expandingReshapeOp.getSrcType() 698 : collapsingReshapeOp.getResultType(); 699 700 ExpansionInfo expansionInfo; 701 if (failed(expansionInfo.compute( 702 genericOp, fusableOpOperand, 703 isExpanding ? expandingReshapeOp.getReassociationMaps() 704 : collapsingReshapeOp.getReassociationMaps(), 705 expandedType.getShape(), collapsedType.getShape(), rewriter))) 706 return llvm::None; 707 708 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) 709 return llvm::None; 710 711 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>( 712 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) { 713 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); 714 })); 715 716 SmallVector<Value> expandedOpOperands; 717 expandedOpOperands.reserve(genericOp.getNumInputs()); 718 for (OpOperand *opOperand : genericOp.getInputOperands()) { 719 if (opOperand == fusableOpOperand) { 720 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() 721 : collapsingReshapeOp.src()); 722 continue; 723 } 724 if (genericOp.isInputTensor(opOperand)) { 725 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 726 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 727 RankedTensorType expandedOperandType = 728 getExpandedType(opOperandType, indexingMap, expansionInfo); 729 if (expandedOperandType != opOperand->get().getType()) { 730 // Reshape the operand to get the right type. 
731 SmallVector<ReassociationIndices> reassociation = 732 getReassociationForExpansion(indexingMap, expansionInfo); 733 if (failed(reshapeLikeShapesAreCompatible( 734 [&](const Twine &msg) { 735 return rewriter.notifyMatchFailure(genericOp, msg); 736 }, 737 opOperandType.getShape(), expandedOperandType.getShape(), 738 reassociation, 739 /*isExpandingReshape=*/true))) 740 return llvm::None; 741 expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>( 742 genericOp.getLoc(), expandedOperandType, opOperand->get(), 743 reassociation)); 744 continue; 745 } 746 } 747 expandedOpOperands.push_back(opOperand->get()); 748 } 749 750 Location loc = genericOp.getLoc(); 751 SmallVector<Value> outputs; 752 for (OpOperand *opOperand : genericOp.getOutputOperands()) { 753 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 754 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 755 RankedTensorType expandedOutputType = 756 getExpandedType(opOperandType, indexingMap, expansionInfo); 757 if (expandedOutputType != opOperand->get().getType()) { 758 SmallVector<ReassociationIndices> reassociation = 759 getReassociationForExpansion(indexingMap, expansionInfo); 760 if (failed(reshapeLikeShapesAreCompatible( 761 [&](const Twine &msg) { 762 return rewriter.notifyMatchFailure(genericOp, msg); 763 }, 764 opOperandType.getShape(), expandedOutputType.getShape(), 765 reassociation, 766 /*isExpandingReshape=*/true))) 767 return llvm::None; 768 outputs.push_back(rewriter.create<tensor::ExpandShapeOp>( 769 genericOp.getLoc(), expandedOutputType, opOperand->get(), 770 reassociation)); 771 } 772 } 773 774 // The iterator types of the expanded op are all parallel. 775 SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(), 776 getParallelIteratorTypeName()); 777 778 TypeRange resultTypes = ValueRange(outputs).getTypes(); 779 auto fusedOp = 780 rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes, 781 /*inputs=*/expandedOpOperands, outputs, 782 expandedOpIndexingMaps, iteratorTypes); 783 Region &fusedRegion = fusedOp->getRegion(0); 784 Region &originalRegion = genericOp->getRegion(0); 785 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin()); 786 787 // Update the index accesses after the expansion. 788 updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo); 789 790 // Reshape the result values to their original shape if this is a collapsing 791 // reshape folded into its consumer. 792 SmallVector<Value> resultVals; 793 for (OpResult opResult : genericOp->getOpResults()) { 794 int64_t resultNumber = opResult.getResultNumber(); 795 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) { 796 SmallVector<ReassociationIndices> reassociation = 797 getReassociationForExpansion( 798 genericOp.getTiedIndexingMap( 799 genericOp.getOutputOperand(resultNumber)), 800 expansionInfo); 801 resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>( 802 genericOp.getLoc(), opResult.getType(), 803 fusedOp->getResult(resultNumber), reassociation)); 804 } else { 805 resultVals.push_back(fusedOp->getResult(resultNumber)); 806 } 807 } 808 // Assuming a single result. 809 return resultVals; 810 } 811 812 namespace { 813 814 /// Pattern to fuse a tensor.collapse_shape op with its consumer generic op, 815 /// when the reshape op is collapsing dimensions. The dimensionality of the loop 816 /// in the consumer is expanded. 
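///
/// As an illustrative sketch (shapes and value names are hypothetical):
///
///   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
///       : tensor<?x4x?xf32> into tensor<?x?xf32>
///   %1 = linalg.generic ... ins(%0 : tensor<?x?xf32>) ...
///
/// is rewritten so that the generic op iterates over the expanded (3-d) space
/// and consumes %arg0 directly; the other operands and the results are
/// reshaped with tensor.expand_shape/tensor.collapse_shape as needed.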
817 class FoldWithProducerReshapeOpByExpansion 818 : public OpRewritePattern<GenericOp> { 819 public: 820 FoldWithProducerReshapeOpByExpansion(MLIRContext *context, 821 ControlFusionFn foldReshapes, 822 PatternBenefit benefit = 1) 823 : OpRewritePattern<GenericOp>(context, benefit), 824 controlFoldingReshapes(std::move(foldReshapes)) {} 825 826 LogicalResult matchAndRewrite(GenericOp genericOp, 827 PatternRewriter &rewriter) const override { 828 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 829 tensor::CollapseShapeOp reshapeOp = 830 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>(); 831 if (!reshapeOp) 832 continue; 833 // Fold only if 834 // - The tensor reshape op is folding. 835 // - All constraints of fusing with reshape by expansion are met. 836 if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) || 837 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand))) 838 continue; 839 840 Optional<SmallVector<Value>> replacementValues = 841 fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); 842 if (!replacementValues) 843 return failure(); 844 rewriter.replaceOp(genericOp, replacementValues.getValue()); 845 return success(); 846 } 847 return failure(); 848 } 849 850 private: 851 ControlFusionFn controlFoldingReshapes; 852 }; 853 854 /// Pattern to fold a tensor.expand_shape op with its producer generic op 855 /// by expanding the dimensionality of the loop in the producer op. 856 struct FoldReshapeWithGenericOpByExpansion 857 : public OpRewritePattern<tensor::ExpandShapeOp> { 858 859 FoldReshapeWithGenericOpByExpansion(MLIRContext *context, 860 ControlFusionFn foldReshapes, 861 PatternBenefit benefit = 1) 862 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), 863 controlFoldingReshapes(std::move(foldReshapes)) {} 864 865 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, 866 PatternRewriter &rewriter) const override { 867 // Fold only if all constraints of fusing with reshape by expansion are met. 868 GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>(); 869 if (!producer || producer.getNumOutputs() != 1 || 870 !isFusableWithReshapeByDimExpansion(producer, 871 producer.getOutputOperand(0)) || 872 !controlFoldingReshapes(producer->getResult(0), 873 reshapeOp->getOpOperand(0))) 874 return failure(); 875 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 876 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 877 if (!replacementValues) 878 return failure(); 879 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 880 return success(); 881 } 882 883 private: 884 ControlFusionFn controlFoldingReshapes; 885 }; 886 } // namespace 887 888 //===---------------------------------------------------------------------===// 889 // Methods and patterns to fuse reshape with linalg.generic operations by 890 // contraction of dimensions. 891 //===---------------------------------------------------------------------===// 892 893 /// For a given list of indices in the range of the `indexingMap` that are 894 /// folded, return the indices of the corresponding domain. Return `llvm::None` 895 /// on failure. Ensures that all the elements of the returned reassociation are 896 /// distinct. 
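///
/// For example (an illustrative sketch with assumed values), for
///   indexingMap = affine_map<(d0, d1, d2) -> (d1, d2)>
/// and rangeReassociation = [0, 1], the returned domain reassociation is
/// [1, 2].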
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return indexingMap.getResults()[pos]
            .cast<AffineDimExpr>()
            .getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            indexingMap.getResult(expr.index() + dimInSequence.index())
                .cast<AffineDimExpr>()
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = linalg.init_tensor [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsable, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
    return {};

  if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<int64_t> reductionDims;
  for (const auto &iteratorType :
       llvm::enumerate(genericOp.iterator_types())) {
    if (isReductionIterator(iteratorType.value())) {
      reductionDims.push_back(iteratorType.index());
    }
  }

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.iterator_types().getValue();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that all folded iterator types are all parallel or all reductions.
    Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary for reductions that are associative and commutative, but
    // a stricter definition of reduction is used for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move the window in `reductionDims` to the start of the folded
        // iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
          return !isDimSequencePreserved(indexingMap,
                                         foldedIterationSpaceDims);
        }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that's an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return mapping from original loop domain to collapsed loop domain. The
  /// mapping is a pair. The first value is the dimension in the collapsed loop
  /// that the original loop is mapped to. The second is the relative position
  /// within the folded list of this domain.
For example if the original loop domain is 3D, and 1159 /// the collapsed loop domain is folding all of it, i.e. 1160 /// 1161 /// ``` 1162 /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]` 1163 /// ``` 1164 /// 1165 /// then 1166 /// 1167 /// ``` 1168 /// origOpToCollapsedOpMapping[0] = {0, 0}; 1169 /// origOpToCollapsedOpMapping[1] = {0, 1}; 1170 /// origOpToCollapsedOpMapping[2] = {0, 2}; 1171 /// origOpToCollapsedOpMapping[3] = {1, 0}; 1172 /// origOpToCollapsedOpMapping[4] = {1, 1}; 1173 /// ``` 1174 /// 1175 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const { 1176 return origOpToCollapsedOpIterationDim; 1177 } 1178 1179 /// Return the collapsed op iteration domain rank. 1180 unsigned getCollapsedOpIterationRank() const { 1181 return collapsedOpToOrigOpIterationDim.size(); 1182 } 1183 1184 private: 1185 /// Map from the iteration domain index in collapsed op to the iteration 1186 /// domain indices in the original op. 1187 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim; 1188 1189 /// Map from iteration domain index in the original op to the iteration domain 1190 /// index in the collapsed op. 1191 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim; 1192 }; 1193 } // namespace 1194 1195 /// Get the iterator types for the collapsed operation given the original 1196 /// iterator types and collapsed dimensions. 1197 static SmallVector<StringRef> 1198 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes, 1199 const CollapsingInfo &collapsingInfo) { 1200 SmallVector<StringRef> collapsedIteratorTypes; 1201 for (ReassociationIndicesRef foldedIterDims : 1202 collapsingInfo.getCollapsedOpToOrigOpMapping()) { 1203 assert(!foldedIterDims.empty() && 1204 "reassociation indices expected to have non-empty sets"); 1205 // Just pick the iterator type of the first folded dim. Pre-condition checks 1206 // expected to have checked that iterator types of all folded dimensions are 1207 // the same. 1208 collapsedIteratorTypes.push_back( 1209 iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue()); 1210 } 1211 return collapsedIteratorTypes; 1212 } 1213 1214 /// Compute the indexing map in the collapsed op that corresponds to the given 1215 /// `indexingMap` of the original operation. 1216 static AffineMap 1217 getCollapsedOpIndexingMap(AffineMap indexingMap, 1218 const CollapsingInfo &collapsingInfo) { 1219 MLIRContext *context = indexingMap.getContext(); 1220 assert(indexingMap.isProjectedPermutation() && 1221 "expected indexing map to be projected permutation"); 1222 SmallVector<AffineExpr> resultExprs; 1223 auto origOpToCollapsedOpMapping = 1224 collapsingInfo.getOrigOpToCollapsedOpMapping(); 1225 for (auto expr : indexingMap.getResults()) { 1226 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 1227 // If the dim is not the first of the collapsed dim, do nothing. 1228 if (origOpToCollapsedOpMapping[dim].second != 0) 1229 continue; 1230 // The next n-dims are guaranteed to be collapsed. So just use the 1231 // iteration dimension of the collapsed op. 1232 resultExprs.push_back( 1233 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context)); 1234 } 1235 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0, 1236 resultExprs, context); 1237 } 1238 1239 /// Return the `reassociation` indices to use to collapse the operand when the 1240 /// iteration space of a generic op is collapsed. 
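///
/// For example (an illustrative sketch with assumed values), if the iteration
/// space folds loops (d1, d2) into a single collapsed loop and the operand's
/// indexing map is affine_map<(d0, d1, d2) -> (d0, d1, d2)>, the returned
/// operand reassociation is [[0], [1, 2]].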
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      // This is the start of a collapsed dimension of the iteration space
      // that is guaranteed to be preserved in the indexing map. The number of
      // folded dims is obtained from the collapsed op to original op mapping.
      unsigned numFoldedDims =
          collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
              .size();
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
      counter += numFoldedDims;
    }
  }
  return operandReassociation;
}

/// Get the new value to use for a given `OpOperand` in the collapsed
/// operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, then nothing needs to be
  // done for this operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
      loc, operand, operandReassociation);
  return reshapeOp.getResult();
}

/// Rewrite the `linalg.index` operations in the original generic op to their
/// values in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block,
                                     const CollapsingInfo &collapsingInfo,
                                     ValueRange loopRange,
                                     PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list resolve the original induction variable
  // values in terms of the folded dimension induction variable.
  //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1304 // can be inverted to 1305 // i2 = i_{folded} % d2 1306 // i1 = (i_{folded} / d2) % d1 1307 // i0 = i_{folded} / (d1 * d2) 1308 llvm::DenseMap<unsigned, Value> indexReplacementVals; 1309 for (auto &foldedDims : 1310 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) { 1311 ReassociationIndicesRef foldedDimsRef(foldedDims.value()); 1312 Value newIndexVal = 1313 rewriter.create<linalg::IndexOp>(loc, foldedDims.index()); 1314 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) { 1315 indexReplacementVals[dim] = 1316 rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]); 1317 newIndexVal = 1318 rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]); 1319 } 1320 indexReplacementVals[foldedDims.value().front()] = newIndexVal; 1321 } 1322 1323 for (auto indexOp : indexOps) { 1324 auto dim = indexOp.dim(); 1325 rewriter.replaceOp(indexOp, indexReplacementVals[dim]); 1326 } 1327 } 1328 1329 /// Implementation of fusion with reshape operation by collapsing dimensions. 1330 static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims( 1331 GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, 1332 OpOperand *fusableOpOperand, PatternRewriter &rewriter) { 1333 // Bail on trivial no-op cases. 1334 if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() || 1335 llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { 1336 return foldedDims.size() <= 1; 1337 })) 1338 return failure(); 1339 1340 CollapsingInfo collapsingInfo; 1341 if (failed(collapsingInfo.initialize(genericOp.getNumLoops(), 1342 foldedIterationDims))) { 1343 return rewriter.notifyMatchFailure( 1344 genericOp, "illegal to collapse specified dimensions"); 1345 } 1346 1347 // Get the iterator types for the operand. 1348 SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes( 1349 genericOp.iterator_types().getValue(), collapsingInfo); 1350 1351 // Get the indexing maps. 1352 auto indexingMaps = llvm::to_vector( 1353 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) { 1354 return getCollapsedOpIndexingMap(map, collapsingInfo); 1355 })); 1356 1357 Location loc = genericOp->getLoc(); 1358 1359 // Get the input operands. 1360 auto inputOperands = llvm::to_vector( 1361 llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) { 1362 return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, 1363 rewriter); 1364 })); 1365 1366 // Get the output operands and result types. 1367 SmallVector<Type> resultTypes; 1368 SmallVector<Value> outputOperands; 1369 resultTypes.reserve(genericOp.getNumOutputs()); 1370 outputOperands.reserve(genericOp.getNumOutputs()); 1371 for (OpOperand *output : genericOp.getOutputOperands()) { 1372 Value newOutput = 1373 getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter); 1374 outputOperands.push_back(newOutput); 1375 resultTypes.push_back(newOutput.getType()); 1376 } 1377 1378 // Create the generic op. 1379 auto collapsedGenericOp = rewriter.create<linalg::GenericOp>( 1380 loc, resultTypes, inputOperands, outputOperands, indexingMaps, 1381 iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); 1382 Block *origOpBlock = &genericOp->getRegion(0).front(); 1383 Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front(); 1384 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, 1385 collapsedOpBlock->getArguments()); 1386 1387 if (collapsedGenericOp.hasIndexSemantics()) { 1388 // Collect the loop range of the generic op. 
/// Implementation of fusion with reshape operation by collapsing dimensions.
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
  // Bail on trivial no-op cases.
  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims,
                   [](ReassociationIndicesRef foldedDims) {
                     return foldedDims.size() <= 1;
                   }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                       foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        genericOp, "illegal to collapse specified dimensions");
  }

  // Get the iterator types for the collapsed operation.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc = genericOp->getLoc();

  // Get the input operands.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert an expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}
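// Example (illustrative, with made-up shapes): collapsing a 4-D generic op
// whose iteration dims are folded as {[0, 1], [2, 3]} and whose operands of
// type tensor<3x4x5x6xf32> use the identity indexing map yields a 2-D generic
// op over tensor<12x30xf32> operands; its result is reshaped back to
// tensor<3x4x5x6xf32> with a tensor.expand_shape so that users of the
// original result are unaffected.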
namespace {

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing the dimensions of the generic op's iteration space.
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
        continue;
      }

      Optional<SmallVector<Value>> replacements =
          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
                                         opOperand, rewriter);
      if (!replacements) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, replacements.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace
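// A minimal usage sketch (assuming a client pass that owns the op being
// transformed; `funcOp` and the control function below are illustrative, not
// part of this file):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateFoldReshapeOpsByCollapsingPatterns(
//       patterns, [](const OpResult &producer, const OpOperand &consumer) {
//         return producer.hasOneUse();
//       });
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));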
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does
/// not handle cases where the constant is not single-valued.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation are the
      // same as the operands and indexing_maps of the generic operation, with
      // the constant operand (and its indexing map) dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};

} // namespace
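// Example (illustrative): if one input of the generic op is defined by
// `arith.constant dense<4.2> : tensor<8xf32>`, that input and its indexing
// map are dropped from the fused op, and every use of the corresponding
// payload block argument is remapped to a scalar `arith.constant 4.2 : f32`
// that the fused region captures from the enclosing block.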
//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparse compiler.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

/// Fold linalg.fill into linalg.generic.
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.region().front();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      payload.getArgument(opOperand->getOperandNumber())
          .replaceAllUsesWith(fillOp.value());
    }
    return success(fillFound);
  }
};
} // namespace
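// Example (illustrative): if an input of the generic op is produced by a
// `linalg.fill` of a scalar `%cst`, every use of the payload block argument
// tied to that input is replaced by `%cst` (captured from the enclosing
// block), so the fill value is read directly and the filled tensor operand is
// no longer read inside the payload.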
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
}

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
// favor of test passes that check the functionality of each of the patterns
// added here individually.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    // Default control function: only fuse when the producer result has a
    // single use.
    ControlFusionFn defaultControlFn = [](const OpResult &producer,
                                          const OpOperand &consumer) {
      return producer.hasOneUse();
    };

    // Add elementwise op fusion patterns.
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);

    // Add folding with reshape by expansion patterns.
    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);

    // Add the sparse tensor rewriting patterns.
    populateSparseTensorRewriting(patterns);

    // General canonicalization patterns.
    AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
        patterns);

    // Add constant folding patterns.
    populateConstantFoldLinalgOperations(patterns, defaultControlFn);

    // Use TopDownTraversal for compile time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}