//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//
#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Returns the indexing map to use in the fused op for the producer operand
/// `producerOpOperand`, given `producerResultIndexMap`, the indexing map of
/// the producer result being fused, and `fusedConsumerArgIndexMap`, the
/// indexing map of that result in the consumer.
38 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 39 OpOperand *producerOpOperand, AffineMap producerResultIndexMap, 40 AffineMap fusedConsumerArgIndexMap) { 41 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 42 // from consumer loop -> consumer arg tensor index/producer result tensor 43 // index. The fused loop is same as the consumer loop. For each producer arg 44 // the indexing map to be computed is a map from consumer loop -> producer 45 // arg tensor index. 46 // producerResultIndexMap is a map from producer loop -> tensor index. 47 // Compute the inverse to get map from tensor index -> producer loop. 48 // The inverse is a map from producer result tensor index -> producer loop. 49 AffineMap invProducerResultIndexMap = 50 inversePermutation(producerResultIndexMap); 51 assert(invProducerResultIndexMap && 52 "expected producer result indexing map to be invertible"); 53 54 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner()); 55 // argMap is a map from producer loop -> producer arg tensor index. 56 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); 57 58 // Compose argMap with invProducerResultIndexMap to get a map from 59 // producer result tensor index -> producer arg tensor index. 60 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 61 62 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 63 // consumer loop/ fused loop -> producer arg tensor index. 64 return t1.compose(fusedConsumerArgIndexMap); 65 } 66 67 /// Conditions for elementwise fusion of generic operations. 68 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, 69 OpOperand *consumerOpOperand) { 70 // Producer and consumer must have tensor semantics. 71 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 72 return false; 73 74 // Verify that 75 // - the producer has all "parallel" iterator type. 
76 if (producer.getNumParallelLoops() != producer.getNumLoops()) 77 return false; 78 79 // Only allow fusing the producer of an input operand for now. 80 // TODO: allow fusing the producer of an output operand. 81 if (!consumer.isInputTensor(consumerOpOperand)) 82 return false; 83 84 // Get the consumer index map. The number of results of the consumer index 85 // map must match the number of loops of the producer. 86 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 87 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 88 return false; 89 90 // Currently support only operations with single result. 91 if (producer.getNumOutputs() != 1) 92 return false; 93 94 // Finally the index_map for the result must be invertible. For now just 95 // verify it is a permutation. 96 AffineMap producerResultIndexMap = 97 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 98 if (!producerResultIndexMap.isPermutation()) 99 return false; 100 101 // Ensure that the fusion does not remove size information required to 102 // get the loop bounds. For non-reduction generics, this is trivially the 103 // case due to the output operand. For reductions, we need to check that after 104 // the fusion, each loop dimension has at least one input that defines it. 
105 if ((consumer.getNumReductionLoops())) { 106 BitVector coveredDims(consumer.getNumLoops(), false); 107 108 auto addToCoveredDims = [&](AffineMap map) { 109 for (auto result : map.getResults()) 110 if (auto dimExpr = result.dyn_cast<AffineDimExpr>()) 111 coveredDims[dimExpr.getPosition()] = true; 112 }; 113 114 for (auto pair : 115 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) { 116 Value operand = std::get<0>(pair); 117 if (operand == consumerOpOperand->get()) 118 continue; 119 AffineMap operandMap = std::get<1>(pair); 120 addToCoveredDims(operandMap); 121 } 122 123 for (OpOperand *operand : producer.getInputOperands()) { 124 AffineMap newIndexingMap = 125 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 126 operand, producerResultIndexMap, consumerIndexMap); 127 addToCoveredDims(newIndexingMap); 128 } 129 if (!coveredDims.all()) 130 return false; 131 } 132 133 return true; 134 } 135 136 /// Generate the region of the fused tensor operation. The region of the fused 137 /// op must be empty. 138 static void 139 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 140 AffineMap consumerToProducerLoopsMap, 141 OpOperand *consumerOpOperand, 142 unsigned nloops) { 143 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 144 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 145 // Build the region of the fused op. 146 Block &producerBlock = producer->getRegion(0).front(); 147 Block &consumerBlock = consumer->getRegion(0).front(); 148 Block *fusedBlock = new Block(); 149 fusedOp.region().push_back(fusedBlock); 150 BlockAndValueMapping mapper; 151 OpBuilder::InsertionGuard guard(rewriter); 152 rewriter.setInsertionPointToStart(fusedBlock); 153 154 // 2. Add an index operation for every fused loop dimension and use the 155 // `consumerToProducerLoopsMap` to map the producer indices. 
156 if (producer.hasIndexSemantics()) { 157 // Add an index operation for every fused loop dimension. 158 unsigned numFusedOpLoops = 159 std::max(producer.getNumLoops(), consumer.getNumLoops()); 160 SmallVector<Value> fusedIndices; 161 fusedIndices.reserve(numFusedOpLoops); 162 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 163 std::back_inserter(fusedIndices), [&](uint64_t dim) { 164 return rewriter.create<IndexOp>(producer.getLoc(), dim); 165 }); 166 for (IndexOp indexOp : 167 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 168 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 169 producer.getLoc(), 170 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 171 mapper.map(indexOp.getResult(), newIndex); 172 } 173 } 174 // TODO: allow fusing the producer of an output operand. 175 assert(consumer.isInputTensor(consumerOpOperand) && 176 "expected producer of input operand"); 177 // 3. Consumer input operands up to consumerIdx (exclusive). 178 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 179 consumerOpOperand->getOperandNumber())) // input assumption. 180 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 181 182 // Replacing consumerIdx requires getting the cloned, yielded, value from 183 // the (cloned) producer block. This happens in step 9. 184 185 // 4. Splice in producer's input operands. 186 for (BlockArgument bbArg : 187 producerBlock.getArguments().take_front(producer.getNumInputs())) 188 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 189 190 // 4.b. Producer output operand/map that is fused needs to be mapped to the 191 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 
192 assert(producer->getNumResults() == 1 && "expected single result producer"); 193 if (producer.isInitTensor(producer.getOutputOperand(0))) { 194 BlockArgument bbArg = producerBlock.getArguments() 195 .drop_front(producer.getNumInputs()) 196 // TODO: bbArg index of 197 .front(); 198 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 199 } 200 // 5. Remaining consumer's input operands (drop past index `consumerIdx`). 201 for (BlockArgument bbArg : 202 consumerBlock.getArguments() 203 .take_front(consumer.getNumInputs()) 204 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 205 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 206 // 6. All of consumer's output operands. 207 for (BlockArgument bbArg : 208 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 209 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 210 // 7. All of producer's output operands except the one fused. 211 // TODO: allow fusion of multi-result producers. 212 assert(producer->getNumResults() == 1 && "expected single result producer"); 213 214 // 8. Clone all producer operations except for the yield and index operations 215 // to the fused operation. 216 for (auto &op : producerBlock.without_terminator()) { 217 if (!isa<IndexOp>(op)) 218 rewriter.clone(op, mapper); 219 } 220 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just 221 // forward the yield operand. 222 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 223 // TODO: allow fusion of multi-result producers. 224 assert(producer->getNumResults() == 1 && "expected single result producer"); 225 unsigned producerResultNumber = 0; 226 Value replacement = 227 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 228 // Sanity checks, if replacement is not already in the mapper then it must be 229 // produced outside. 
230 if (replacement == yieldOp.getOperand(producerResultNumber)) { 231 if (auto bb = replacement.dyn_cast<BlockArgument>()) 232 assert(bb.getOwner() != &producerBlock && 233 "yielded block argument must have been mapped"); 234 else 235 assert(!producer->isAncestor(replacement.getDefiningOp()) && 236 "yielded value must have been mapped"); 237 } 238 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 239 replacement); 240 // 10. Clone operations from the consumer to the fused op. 241 for (auto &op : consumerBlock.getOperations()) 242 rewriter.clone(op, mapper); 243 244 // Sanity checks. 245 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 246 "Ill-formed GenericOp region"); 247 } 248 249 static Optional<SmallVector<Value>> 250 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 251 const ControlFusionFn &controlFn, 252 PatternRewriter &rewriter) { 253 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 254 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 255 !controlFn(producer->getResult(0), *consumerOpOperand)) 256 return llvm::None; 257 258 // TODO: allow fusing the producer of an output operand. 259 assert(consumer.isInputTensor(consumerOpOperand) && 260 "expected producer of input operand"); 261 262 // Compute the fused operands list and indexing maps. 263 SmallVector<Value> fusedOperands; 264 SmallVector<AffineMap> fusedIndexMaps; 265 fusedOperands.reserve(producer->getNumOperands() + 266 consumer->getNumOperands()); 267 fusedIndexMaps.reserve(producer->getNumOperands() + 268 consumer->getNumOperands()); 269 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 270 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 
271 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 272 SmallVector<OpOperand *>::iterator it = 273 llvm::find(consumerInputs, consumerOpOperand); 274 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 275 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 276 fusedOperands.push_back(opOperand->get()); 277 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 278 } 279 // 4. Splice in producer's input operands/maps. 280 assert(producer->getNumResults() == 1 && "expected single result producer"); 281 AffineMap producerResultIndexMap = 282 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 283 for (OpOperand *opOperand : producer.getInputOperands()) { 284 fusedOperands.push_back(opOperand->get()); 285 // Compute indexing maps for the producer args in the fused operation. 286 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 287 opOperand, producerResultIndexMap, 288 consumer.getTiedIndexingMap(consumerOpOperand)); 289 fusedIndexMaps.push_back(map); 290 } 291 // 4.b. Producer output operand/map that is fused needs to be passed if it is 292 // an "initTensor" (i.e. its value is actually read). 293 assert(producer->getNumResults() == 1 && "expected single result producer"); 294 if (producer.isInitTensor(producer.getOutputOperand(0))) { 295 fusedOperands.push_back(producer.getOutputOperand(0)->get()); 296 // Compute indexing maps for the producer args in the fused operation. 297 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 298 producer.getOutputOperand(0), producerResultIndexMap, 299 consumer.getTiedIndexingMap(consumerOpOperand)); 300 fusedIndexMaps.push_back(map); 301 } 302 // 5. Remaining consumer's input operands/maps (drop past index 303 // `consumerIdx`). 
304 for (OpOperand *opOperand : 305 llvm::make_range(std::next(it), consumerInputs.end())) { 306 fusedOperands.push_back(opOperand->get()); 307 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 308 } 309 // 6. All of consumer's output operands (skip operands: added by the builder). 310 for (OpOperand *opOperand : consumer.getOutputOperands()) 311 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 312 // 7. All of producer's output operands/maps except the one fused. 313 // TODO: allow fusion of multi-result producers. 314 assert(producer->getNumResults() == 1 && "expected single result producer"); 315 316 // Generate the fused op. 317 SmallVector<Value> consumerOutputs = consumer.getOutputOperands(); 318 auto fusedOp = rewriter.create<GenericOp>( 319 consumer.getLoc(), consumer->getResultTypes(), 320 /*inputs=*/fusedOperands, 321 // TODO: handle outputs. 322 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 323 consumer.iterator_types(), 324 /*doc=*/nullptr, 325 /*library_call=*/nullptr); 326 if (!fusedOp.getShapesToLoopsMap()) { 327 // Fused op has invalid indexing maps. Typically this means something is off 328 // in the input, but going ahead here would result in verification errors. 329 // So cleanup and abort. 330 rewriter.eraseOp(fusedOp); 331 return llvm::None; 332 } 333 334 // Construct an AffineMap from consumer loops to producer loops. 
335 // consumer loop -> tensor index 336 AffineMap consumerResultIndexMap = 337 consumer.getTiedIndexingMap(consumerOpOperand); 338 // tensor index -> producer loop 339 AffineMap invProducerResultIndexMap = 340 inversePermutation(producerResultIndexMap); 341 assert(invProducerResultIndexMap && 342 "expected producer result indexig map to be invertible"); 343 // consumer loop -> producer loop 344 AffineMap consumerToProducerLoopsMap = 345 invProducerResultIndexMap.compose(consumerResultIndexMap); 346 347 generateFusedElementwiseOpRegion(rewriter, fusedOp, 348 consumerToProducerLoopsMap, 349 consumerOpOperand, consumer.getNumLoops()); 350 return SmallVector<Value>(fusedOp->getResults()); 351 } 352 353 static Optional<SmallVector<Value>> 354 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, 355 GenericOp producer, const ControlFusionFn &controlFn) { 356 if (producer->getNumResults() != 1) 357 return llvm::None; 358 359 return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn, 360 rewriter); 361 } 362 363 namespace { 364 /// Patterns to fuse a generic op, with the producer of its operands. 365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> { 366 public: 367 FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun, 368 PatternBenefit benefit = 1) 369 : OpRewritePattern<GenericOp>(context, benefit), 370 controlFn(std::move(fun)) {} 371 372 LogicalResult matchAndRewrite(GenericOp genericOp, 373 PatternRewriter &rewriter) const override { 374 // Find the first operand that is defined by another generic op on tensors. 
375 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 376 auto producer = 377 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp()); 378 if (!producer || !producer.hasTensorSemantics()) 379 continue; 380 Optional<SmallVector<Value>> fusedOpResults = 381 fuseElementwiseOps(rewriter, opOperand, producer, controlFn); 382 if (fusedOpResults) { 383 rewriter.replaceOp(genericOp, *fusedOpResults); 384 return success(); 385 } 386 } 387 return failure(); 388 } 389 390 private: 391 ControlFusionFn controlFn; 392 }; 393 } // namespace 394 395 //===---------------------------------------------------------------------===// 396 // Methods and patterns that fuse reshape ops with elementwise operations by 397 // expanding the dimensionality of the elementwise operations. 398 //===---------------------------------------------------------------------===// 399 400 /// Conditions for folding a generic operation with a reshape op by expanding 401 /// the iteration space dimensionality for tensor operations. These are 402 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 403 /// following fusion pattern. 404 /// 405 /// Consider 406 /// 407 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>) 408 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 409 /// affine_map<(d0, d1, d2) -> (d1, d2)>, 410 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] 411 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] 412 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 413 /// 414 /// The reshape can be folded into the `genericOp` if its loop dimensionality 415 /// is increased to match the result (operand) of the tensor.expand_shape. 416 /// The indexing_map of the fused tensor in the `genericOp` and the 417 /// reassociation map helps compute the indexing maps of the modified op. 
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
/// %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since operands to the linalg generic are now higher dimensional, reshapes
/// can be introduced to make it consistent
///
/// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
/// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with its producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  // Number of loop dimensions in the original (unexpanded) op.
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  // Number of loop dimensions in the expanded op.
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  // Expanded-op dimensions that original dimension `i` is split into.
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  // Shapes of the expanded dimensions that original dimension `i` maps to.
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimension in the expanded op that correspond to each
  // dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    // Slice of the expanded shape covered by this group of folded dimensions.
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  // Expanded dimensions are numbered consecutively, grouped by the original
  // dimension they come from.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the expanded
/// op. That requires the shape of the expanded dimensions to be static (at
/// least all but the most significant). For now check that these are all
/// statically sized. Note that this could be extended to handle dynamic case,
/// but the implementation below uses `affine.apply` which seems to have issues
/// when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  // Only ops that materialize loop indices in their body are affected.
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    // All but the outermost expanded dimension must be statically sized.
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap` of
/// the original operation.
584 static AffineMap 585 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 586 const ExpansionInfo &expansionInfo) { 587 SmallVector<AffineExpr> newExprs; 588 for (AffineExpr expr : indexingMap.getResults()) { 589 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 590 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 591 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 592 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 593 })); 594 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 595 } 596 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 597 indexingMap.getNumSymbols(), newExprs, 598 builder.getContext()); 599 } 600 601 /// Return the type of the operand/result to use in the expanded op given the 602 /// type in the original op. 603 static RankedTensorType getExpandedType(RankedTensorType originalType, 604 AffineMap indexingMap, 605 const ExpansionInfo &expansionInfo) { 606 SmallVector<int64_t> expandedShape; 607 for (AffineExpr expr : indexingMap.getResults()) { 608 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 609 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 610 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 611 } 612 return RankedTensorType::get(expandedShape, originalType.getElementType()); 613 } 614 615 /// Returns the reassociation maps to use in the `tensor.expand_shape` 616 /// operation to convert the operands of the original operation to operands of 617 /// the expanded operation. The same method is used to compute the 618 /// `tensor.collapse_shape` used to collapse the result of the expanded 619 /// op to get the value that can replace all uses of the results of the original 620 /// op. 
621 static SmallVector<ReassociationIndices> 622 getReassociationForExpansion(AffineMap indexingMap, 623 const ExpansionInfo &expansionInfo) { 624 SmallVector<ReassociationIndices> reassociation; 625 unsigned numReshapeDims = 0; 626 for (AffineExpr expr : indexingMap.getResults()) { 627 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 628 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 629 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 630 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 631 reassociation.emplace_back(std::move(indices)); 632 numReshapeDims += numExpandedDims; 633 } 634 return reassociation; 635 } 636 637 /// Update the body of an expanded linalg operation having index semantics. The 638 /// indices of the original operation need to be recovered by linearizing the 639 /// indices of the correspoding dimensions of the expanded operation. For now it 640 /// is assumed that the shapes of the expanded operation needed for 641 /// linearization are static. 642 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, 643 Location loc, Region &fusedRegion, 644 const ExpansionInfo &expansionInfo) { 645 // Replace the original indices by the linearization of the expanded indices. 646 for (IndexOp indexOp : 647 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) { 648 ArrayRef<int64_t> expandedDims = 649 expansionInfo.getExpandedDims(indexOp.dim()); 650 assert(!expandedDims.empty() && "expected valid expansion info"); 651 652 // Skip index operations that are not affected by the expansion. 653 if (expandedDims.size() == 1 && 654 expandedDims.front() == (int64_t)indexOp.dim()) 655 continue; 656 657 // Linearize the expanded indices of the original index dimension. 
658 OpBuilder::InsertionGuard guard(rewriter); 659 rewriter.setInsertionPointAfter(indexOp); 660 ArrayRef<int64_t> expandedDimsShape = 661 expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front(); 662 SmallVector<Value> expandedIndices; 663 expandedIndices.reserve(expandedDims.size() - 1); 664 llvm::transform( 665 expandedDims.drop_front(), std::back_inserter(expandedIndices), 666 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); }); 667 Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front()); 668 for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) { 669 assert(!ShapedType::isDynamic(std::get<0>(it))); 670 AffineExpr idx, acc; 671 bindDims(rewriter.getContext(), idx, acc); 672 newIndex = rewriter.create<AffineApplyOp>( 673 indexOp.getLoc(), idx + acc * std::get<0>(it), 674 ValueRange{std::get<1>(it), newIndex}); 675 } 676 rewriter.replaceOp(indexOp, newIndex); 677 } 678 } 679 680 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op 681 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes 682 /// that those conditions have been satisfied. 683 static Optional<SmallVector<Value>> 684 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, 685 OpOperand *fusableOpOperand, 686 PatternRewriter &rewriter) { 687 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && 688 "preconditions for fuse operation failed"); 689 // Check if reshape is expanding or collapsing. 690 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp); 691 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp); 692 bool isExpanding = (expandingReshapeOp != nullptr); 693 RankedTensorType expandedType = isExpanding 694 ? expandingReshapeOp.getResultType() 695 : collapsingReshapeOp.getSrcType(); 696 RankedTensorType collapsedType = isExpanding 697 ? 
expandingReshapeOp.getSrcType() 698 : collapsingReshapeOp.getResultType(); 699 700 ExpansionInfo expansionInfo; 701 if (failed(expansionInfo.compute( 702 genericOp, fusableOpOperand, 703 isExpanding ? expandingReshapeOp.getReassociationMaps() 704 : collapsingReshapeOp.getReassociationMaps(), 705 expandedType.getShape(), collapsedType.getShape(), rewriter))) 706 return llvm::None; 707 708 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) 709 return llvm::None; 710 711 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>( 712 llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) { 713 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); 714 })); 715 716 SmallVector<Value> expandedOpOperands; 717 expandedOpOperands.reserve(genericOp.getNumInputs()); 718 for (OpOperand *opOperand : genericOp.getInputOperands()) { 719 if (opOperand == fusableOpOperand) { 720 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc() 721 : collapsingReshapeOp.getSrc()); 722 continue; 723 } 724 if (genericOp.isInputTensor(opOperand)) { 725 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 726 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 727 RankedTensorType expandedOperandType = 728 getExpandedType(opOperandType, indexingMap, expansionInfo); 729 if (expandedOperandType != opOperand->get().getType()) { 730 // Reshape the operand to get the right type. 
731 SmallVector<ReassociationIndices> reassociation = 732 getReassociationForExpansion(indexingMap, expansionInfo); 733 if (failed(reshapeLikeShapesAreCompatible( 734 [&](const Twine &msg) { 735 return rewriter.notifyMatchFailure(genericOp, msg); 736 }, 737 opOperandType.getShape(), expandedOperandType.getShape(), 738 reassociation, 739 /*isExpandingReshape=*/true))) 740 return llvm::None; 741 expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>( 742 genericOp.getLoc(), expandedOperandType, opOperand->get(), 743 reassociation)); 744 continue; 745 } 746 } 747 expandedOpOperands.push_back(opOperand->get()); 748 } 749 750 Location loc = genericOp.getLoc(); 751 SmallVector<Value> outputs; 752 for (OpOperand *opOperand : genericOp.getOutputOperands()) { 753 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 754 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 755 RankedTensorType expandedOutputType = 756 getExpandedType(opOperandType, indexingMap, expansionInfo); 757 if (expandedOutputType != opOperand->get().getType()) { 758 SmallVector<ReassociationIndices> reassociation = 759 getReassociationForExpansion(indexingMap, expansionInfo); 760 if (failed(reshapeLikeShapesAreCompatible( 761 [&](const Twine &msg) { 762 return rewriter.notifyMatchFailure(genericOp, msg); 763 }, 764 opOperandType.getShape(), expandedOutputType.getShape(), 765 reassociation, 766 /*isExpandingReshape=*/true))) 767 return llvm::None; 768 outputs.push_back(rewriter.create<tensor::ExpandShapeOp>( 769 genericOp.getLoc(), expandedOutputType, opOperand->get(), 770 reassociation)); 771 } 772 } 773 774 // The iterator types of the expanded op are all parallel. 
  // The iterator types of the expanded op are all parallel; preconditions
  // checked in `isFusableWithReshapeByDimExpansion` ensure the original op has
  // only parallel iterators.
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  // Clone the payload of the original op into the expanded op.
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
817 class FoldWithProducerReshapeOpByExpansion 818 : public OpRewritePattern<GenericOp> { 819 public: 820 FoldWithProducerReshapeOpByExpansion(MLIRContext *context, 821 ControlFusionFn foldReshapes, 822 PatternBenefit benefit = 1) 823 : OpRewritePattern<GenericOp>(context, benefit), 824 controlFoldingReshapes(std::move(foldReshapes)) {} 825 826 LogicalResult matchAndRewrite(GenericOp genericOp, 827 PatternRewriter &rewriter) const override { 828 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 829 tensor::CollapseShapeOp reshapeOp = 830 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>(); 831 if (!reshapeOp) 832 continue; 833 // Fold only if 834 // - The tensor reshape op is folding. 835 // - All constraints of fusing with reshape by expansion are met. 836 if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) || 837 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand))) 838 continue; 839 840 Optional<SmallVector<Value>> replacementValues = 841 fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); 842 if (!replacementValues) 843 return failure(); 844 rewriter.replaceOp(genericOp, *replacementValues); 845 return success(); 846 } 847 return failure(); 848 } 849 850 private: 851 ControlFusionFn controlFoldingReshapes; 852 }; 853 854 /// Pattern to fold a tensor.expand_shape op with its producer generic op 855 /// by expanding the dimensionality of the loop in the producer op. 
856 struct FoldReshapeWithGenericOpByExpansion 857 : public OpRewritePattern<tensor::ExpandShapeOp> { 858 859 FoldReshapeWithGenericOpByExpansion(MLIRContext *context, 860 ControlFusionFn foldReshapes, 861 PatternBenefit benefit = 1) 862 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), 863 controlFoldingReshapes(std::move(foldReshapes)) {} 864 865 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, 866 PatternRewriter &rewriter) const override { 867 // Fold only if all constraints of fusing with reshape by expansion are met. 868 GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>(); 869 if (!producer || producer.getNumOutputs() != 1 || 870 !isFusableWithReshapeByDimExpansion(producer, 871 producer.getOutputOperand(0)) || 872 !controlFoldingReshapes(producer->getResult(0), 873 reshapeOp->getOpOperand(0))) 874 return failure(); 875 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 876 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 877 if (!replacementValues) 878 return failure(); 879 rewriter.replaceOp(reshapeOp, *replacementValues); 880 return success(); 881 } 882 883 private: 884 ControlFusionFn controlFoldingReshapes; 885 }; 886 } // namespace 887 888 //===---------------------------------------------------------------------===// 889 // Methods and patterns to fuse reshape with linalg.generic operations by 890 // contraction of dimensions. 891 //===---------------------------------------------------------------------===// 892 893 /// For a given list of indices in the range of the `indexingMap` that are 894 /// folded, return the indices of the corresponding domain. Return `llvm::None` 895 /// on failure. Ensures that all the elements of the returned reassociation are 896 /// distinct. 
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  // Translate each folded result position into the loop (domain) dimension
  // that appears at that position. Valid since every result of a projected
  // permutation is an AffineDimExpr.
  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return indexingMap.getResults()[pos]
            .cast<AffineDimExpr>()
            .getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      // Not enough results left to hold the whole sequence.
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if sequence is preserved, i.e. appears contiguously and in
      // order starting at this result position.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            indexingMap.getResult(expr.index() + dimInSequence.index())
                .cast<AffineDimExpr>()
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If position in the expr (which is of type AffineDimExpr) is part
    // of sequence, return false here. This implies the entire sequence does not
    // exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of sequence found. Return true.
  return true;
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with the a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = linalg.init_tensor [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = linalg.init_tensor [..]
//      : tensor<4x?xf32>
// %2 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the accesses pattern. When no dimensions of the iteration
// space are collapsable, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
    return {};

  // All indexing maps must be projected permutations; the collapsing logic
  // below relies on every result being an AffineDimExpr.
  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<int64_t> reductionDims;
  for (const auto &iteratorType : llvm::enumerate(genericOp.iterator_types())) {
    if (isReductionIterator(iteratorType.value())) {
      reductionDims.push_back(iteratorType.index());
    }
  }

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.iterator_types().getValue();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    // Map the folded range (result) dims of the fusable operand to the
    // iteration-space (domain) dims they index.
    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that all folded iterator types are all parallel or all reductions.
    Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If sizes don't match, trivially not contiguous. This condition
        // should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  /// Validate `foldedIterationDims` against an iteration domain with
  /// `origNumLoops` loops and build the two-way mapping between the original
  /// and collapsed loop domains. Fails on out-of-range or repeated dims.
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that's illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    // Order the collapsed-loop groups by their first original dim so the
    // collapsed domain follows the original loop order.
    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    // Build the inverse mapping: original dim -> (collapsed dim, position
    // within the folded group).
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return mapping from original loop domain to collapsed loop domain. The
  /// mapping is a pair. First value is the dimension in the collapsed loop that
  /// the original loop is mapped to. Second is the relative position in folded
  /// list of this domain. For example if the original loop domain is 3D, and
  /// the collapsed loop domain is folding all of it, i.e.
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from iteration domain index in the original op to the iteration domain
  /// index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<StringRef>
getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<StringRef> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition checks
    // expected to have checked that iterator types of all folded dimensions are
    // the same.
    collapsedIteratorTypes.push_back(
        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    // If the dim is not the first of the collapsed dim, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The next n-dims are guaranteed to be collapsed. So just use the
    // iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}

/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      // This is the start of a collapsed dimensions of the iteration that
      // is guaranteed to be preserved in the indexing map. The number of folded
      // dims is obtained from the collapsed op to original op mapping.
      unsigned numFoldedDims =
          collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
              .size();
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
      counter += numFoldedDims;
    }
  }
  return operandReassociation;
}

/// Get the new value to use for a given `OpOperand` in the collapsed operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassocation for the operand is same as the
  // number of results of the indexing map, then nothing to do for this operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
      loc, operand, operandReassociation);
  return reshapeOp.getResult();
}

/// Modify the `linalg.index` operations in the original generic op, to its
/// value in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block,
                                     const CollapsingInfo &collapsingInfo,
                                     ValueRange loopRange,
                                     PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list resolve the original induction variable
  // values in terms of the folded dimension induction variable.
  //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
  // can be inverted to
  //   i2 = i_{folded} % d2
  //   i1 = (i_{folded} / d2) % d1
  //   i0 = i_{folded} / (d1 * d2)
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto &foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
    // Peel off the innermost original dims with mod/div; the outermost dim
    // receives whatever quotient remains.
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      indexReplacementVals[dim] =
          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
      newIndexVal =
          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  // Replace each original linalg.index with its reconstructed value.
  for (auto indexOp : indexOps) {
    auto dim = indexOp.dim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}

/// Implementation of fusion with reshape operation by collapsing dimensions.
/// Returns the values replacing the results of `genericOp`, or failure if the
/// collapse cannot be performed.
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
  // Bail on trivial no-op cases.
  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                       foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        genericOp, "illegal to collapse specified dimensions");
  }

  // Get the iterator types for the operand.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc = genericOp->getLoc();

  // Get the input operands, collapsed where their indexing maps require it.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op with an empty body region, then move the original
  // payload into it.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    // The index-reconstruction arithmetic assumes zero-based, unit-stride
    // loops.
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}

namespace {

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// contracting dimensions of the loop.
1437 class FoldWithProducerReshapeOpByCollapsing 1438 : public OpRewritePattern<GenericOp> { 1439 public: 1440 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context, 1441 ControlFusionFn foldReshapes, 1442 PatternBenefit benefit = 1) 1443 : OpRewritePattern<GenericOp>(context, benefit), 1444 controlFoldingReshapes(std::move(foldReshapes)) {} 1445 1446 LogicalResult matchAndRewrite(GenericOp genericOp, 1447 PatternRewriter &rewriter) const override { 1448 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 1449 tensor::ExpandShapeOp reshapeOp = 1450 opOperand->get().getDefiningOp<tensor::ExpandShapeOp>(); 1451 if (!reshapeOp) 1452 continue; 1453 1454 SmallVector<ReassociationIndices> collapsableIterationDims = 1455 getCollapsableIterationSpaceDims(genericOp, opOperand, 1456 reshapeOp.getReassociationIndices()); 1457 if (collapsableIterationDims.empty() || 1458 !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) { 1459 continue; 1460 } 1461 1462 Optional<SmallVector<Value>> replacements = 1463 collapseGenericOpIterationDims(genericOp, collapsableIterationDims, 1464 opOperand, rewriter); 1465 if (!replacements) { 1466 return rewriter.notifyMatchFailure( 1467 genericOp, "failed to do the fusion by collapsing transformation"); 1468 } 1469 1470 rewriter.replaceOp(genericOp, *replacements); 1471 return success(); 1472 } 1473 return failure(); 1474 } 1475 1476 private: 1477 ControlFusionFn controlFoldingReshapes; 1478 }; 1479 } // namespace 1480 1481 //===---------------------------------------------------------------------===// 1482 // Methods and patterns that fuse constants with linalg.generic operations. 1483 //===---------------------------------------------------------------------===// 1484 1485 namespace { 1486 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1487 /// handle cases where the constant is not single-valued. 
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      // Set by `isScalarOrSplatConstantOp` to the single scalar value of the
      // matched constant.
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          // Splat dense constant of int/float element type.
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          // Scalar integer constant.
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          // Scalar float constant.
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation the same as
      // the operands and indexing_maps of the generic operations with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparse compiler.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
1619 auto definingOp = operandVal.getDefiningOp<InitTensorOp>(); 1620 if (definingOp) 1621 continue; 1622 modifiedOutput = true; 1623 SmallVector<Value> dynamicDims; 1624 for (const auto &dim : llvm::enumerate(operandType.getShape())) { 1625 if (dim.value() != ShapedType::kDynamicSize) 1626 continue; 1627 dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>( 1628 loc, operandVal, dim.index())); 1629 } 1630 Value initTensor = rewriter.create<InitTensorOp>( 1631 loc, dynamicDims, operandType.getShape(), 1632 operandType.getElementType()); 1633 op->setOperand(opOperand->getOperandNumber(), initTensor); 1634 } 1635 } 1636 if (!modifiedOutput) { 1637 rewriter.cancelRootUpdate(op); 1638 return failure(); 1639 } 1640 rewriter.finalizeRootUpdate(op); 1641 return success(); 1642 } 1643 }; 1644 1645 /// Fold linalg.fill into linalg.generic 1646 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> { 1647 using OpRewritePattern<GenericOp>::OpRewritePattern; 1648 1649 LogicalResult matchAndRewrite(GenericOp genericOp, 1650 PatternRewriter &rewriter) const override { 1651 if (!genericOp.hasTensorSemantics()) 1652 return failure(); 1653 bool fillFound = false; 1654 Block &payload = genericOp.region().front(); 1655 for (OpOperand *opOperand : genericOp.getInputOperands()) { 1656 if (!genericOp.payloadUsesValueFromOperand(opOperand)) 1657 continue; 1658 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>(); 1659 if (!fillOp) 1660 continue; 1661 fillFound = true; 1662 payload.getArgument(opOperand->getOperandNumber()) 1663 .replaceAllUsesWith(fillOp.value()); 1664 } 1665 return success(fillFound); 1666 } 1667 }; 1668 } // namespace 1669 1670 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( 1671 RewritePatternSet &patterns, 1672 const ControlFusionFn &controlFoldingReshapes) { 1673 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(), 1674 controlFoldingReshapes); 1675 
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), 1676 controlFoldingReshapes); 1677 } 1678 1679 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( 1680 RewritePatternSet &patterns, 1681 const ControlFusionFn &controlFoldingReshapes) { 1682 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(), 1683 controlFoldingReshapes); 1684 } 1685 1686 void mlir::linalg::populateElementwiseOpsFusionPatterns( 1687 RewritePatternSet &patterns, 1688 const ControlFusionFn &controlElementwiseOpsFusion) { 1689 auto *context = patterns.getContext(); 1690 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion); 1691 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant, 1692 RemoveOutsDependency>(context); 1693 } 1694 1695 //===---------------------------------------------------------------------===// 1696 // Passes 1697 //===---------------------------------------------------------------------===// 1698 1699 namespace { 1700 1701 /// Pass that fuses generic ops on tensors. Used only for testing. 1702 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the 1703 // patterns added here heavily depends on the cost function used. Having an 1704 // opinionated pass of this form is not recommended. Deprecate this pass in 1705 // favor of test passes that check the functionality of each of the patterns 1706 // added here individually. 1707 struct LinalgElementwiseOpFusionPass 1708 : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> { 1709 void runOnOperation() override { 1710 Operation *op = getOperation(); 1711 MLIRContext *context = op->getContext(); 1712 RewritePatternSet patterns(context); 1713 1714 // Add folding with reshape by expansion patterns. 1715 ControlFusionFn defaultControlFn = [](const OpResult &producer, 1716 const OpOperand &consumer) { 1717 return producer.hasOneUse(); 1718 }; 1719 1720 // Add elementwise op fusion patterns. 
1721 populateElementwiseOpsFusionPatterns(patterns, defaultControlFn); 1722 populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn); 1723 1724 // General canonicalization patterns. 1725 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 1726 GenericOp::getCanonicalizationPatterns(patterns, context); 1727 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); 1728 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); 1729 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns( 1730 patterns); 1731 1732 // Add constant folding patterns. 1733 populateConstantFoldLinalgOperations(patterns, defaultControlFn); 1734 1735 // Use TopDownTraversal for compile time reasons 1736 GreedyRewriteConfig grc; 1737 grc.useTopDownTraversal = true; 1738 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), 1739 grc); 1740 } 1741 }; 1742 1743 } // namespace 1744 1745 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() { 1746 return std::make_unique<LinalgElementwiseOpFusionPass>(); 1747 } 1748