1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===/// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion on tensors operations pass. 10 // 11 //===----------------------------------------------------------------------===// 12 #include <utility> 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/IR/AffineExpr.h" 21 #include "mlir/IR/AffineMap.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Support/LLVM.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 using namespace mlir; 28 using namespace mlir::linalg; 29 30 //===---------------------------------------------------------------------===// 31 // Methods and patterns that fuse elementwise `linalg.generic` operations. 32 //===---------------------------------------------------------------------===// 33 34 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of 35 /// the `producer` to use in the fused operation given the indexing map of the 36 /// result of the producer in the consumer. 37 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 38 OpOperand *producerOpOperand, AffineMap producerResultIndexMap, 39 AffineMap fusedConsumerArgIndexMap) { 40 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 41 // from consumer loop -> consumer arg tensor index/producer result tensor 42 // index. The fused loop is same as the consumer loop. For each producer arg 43 // the indexing map to be computed is a map from consumer loop -> producer 44 // arg tensor index. 45 // producerResultIndexMap is a map from producer loop -> tensor index. 46 // Compute the inverse to get map from tensor index -> producer loop. 47 // The inverse is a map from producer result tensor index -> producer loop. 48 AffineMap invProducerResultIndexMap = 49 inversePermutation(producerResultIndexMap); 50 assert(invProducerResultIndexMap && 51 "expected producer result indexig map to be invertible"); 52 53 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner()); 54 // argMap is a map from producer loop -> producer arg tensor index. 55 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); 56 57 // Compose argMap with invProducerResultIndexMap to get a map from 58 // producer result tensor index -> producer arg tensor index. 59 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 60 61 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 62 // consumer loop/ fused loop -> producer arg tensor index. 63 return t1.compose(fusedConsumerArgIndexMap); 64 } 65 66 /// Conditions for elementwise fusion of generic operations. 67 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, 68 OpOperand *consumerOpOperand) { 69 // Producer and consumer must have tensor semantics. 70 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 71 return false; 72 73 // Verify that 74 // - the producer has all "parallel" iterator type. 75 if (producer.getNumParallelLoops() != producer.getNumLoops()) 76 return false; 77 78 // Only allow fusing the producer of an input operand for now. 79 // TODO: allow fusing the producer of an output operand. 80 if (!consumer.isInputTensor(consumerOpOperand)) 81 return false; 82 83 // Get the consumer index map. The number of results of the consumer index 84 // map must match the number of loops of the producer. 85 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 86 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 87 return false; 88 89 // Currently support only operations with single result. 90 if (producer.getNumOutputs() != 1) 91 return false; 92 93 // Finally the index_map for the result must be invertible. For now just 94 // verify it is a permutation. 95 AffineMap producerResultIndexMap = 96 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 97 if (!producerResultIndexMap.isPermutation()) 98 return false; 99 100 // Ensure that the fusion does not remove size information required to 101 // get the loop bounds. For non-reduction generics, this is trivially the 102 // case due to the output operand. For reductions, we need to check that after 103 // the fusion, each loop dimension has at least one input that defines it. 104 if ((consumer.getNumReductionLoops())) { 105 BitVector coveredDims(consumer.getNumLoops(), false); 106 107 auto addToCoveredDims = [&](AffineMap map) { 108 for (auto result : map.getResults()) 109 if (auto dimExpr = result.dyn_cast<AffineDimExpr>()) 110 coveredDims[dimExpr.getPosition()] = true; 111 }; 112 113 for (auto pair : 114 llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) { 115 Value operand = std::get<0>(pair); 116 if (operand == consumerOpOperand->get()) 117 continue; 118 AffineMap operandMap = std::get<1>(pair); 119 addToCoveredDims(operandMap); 120 } 121 122 for (OpOperand *operand : producer.getInputOperands()) { 123 AffineMap newIndexingMap = 124 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 125 operand, producerResultIndexMap, consumerIndexMap); 126 addToCoveredDims(newIndexingMap); 127 } 128 if (!coveredDims.all()) 129 return false; 130 } 131 132 return true; 133 } 134 135 /// Generate the region of the fused tensor operation. The region of the fused 136 /// op must be empty. 137 static void 138 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 139 AffineMap consumerToProducerLoopsMap, 140 OpOperand *consumerOpOperand, 141 unsigned nloops) { 142 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 143 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 144 // Build the region of the fused op. 145 Block &producerBlock = producer->getRegion(0).front(); 146 Block &consumerBlock = consumer->getRegion(0).front(); 147 Block *fusedBlock = new Block(); 148 fusedOp.region().push_back(fusedBlock); 149 BlockAndValueMapping mapper; 150 OpBuilder::InsertionGuard guard(rewriter); 151 rewriter.setInsertionPointToStart(fusedBlock); 152 153 // 2. Add an index operation for every fused loop dimension and use the 154 // `consumerToProducerLoopsMap` to map the producer indices. 155 if (producer.hasIndexSemantics()) { 156 // Add an index operation for every fused loop dimension. 157 unsigned numFusedOpLoops = 158 std::max(producer.getNumLoops(), consumer.getNumLoops()); 159 SmallVector<Value> fusedIndices; 160 fusedIndices.reserve(numFusedOpLoops); 161 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 162 std::back_inserter(fusedIndices), [&](uint64_t dim) { 163 return rewriter.create<IndexOp>(producer.getLoc(), dim); 164 }); 165 for (IndexOp indexOp : 166 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 167 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 168 producer.getLoc(), 169 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 170 mapper.map(indexOp.getResult(), newIndex); 171 } 172 } 173 // TODO: allow fusing the producer of an output operand. 174 assert(consumer.isInputTensor(consumerOpOperand) && 175 "expected producer of input operand"); 176 // 3. Consumer input operands up to consumerIdx (exclusive). 177 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 178 consumerOpOperand->getOperandNumber())) // input assumption. 179 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 180 181 // Replacing consumerIdx requires getting the cloned, yielded, value from 182 // the (cloned) producer block. This happens in step 9. 183 184 // 4. Splice in producer's input operands. 185 for (BlockArgument bbArg : 186 producerBlock.getArguments().take_front(producer.getNumInputs())) 187 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 188 189 // 4.b. Producer output operand/map that is fused needs to be mapped to the 190 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 191 assert(producer->getNumResults() == 1 && "expected single result producer"); 192 if (producer.isInitTensor(producer.getOutputOperand(0))) { 193 BlockArgument bbArg = producerBlock.getArguments() 194 .drop_front(producer.getNumInputs()) 195 // TODO: bbArg index of 196 .front(); 197 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 198 } 199 // 5. Remaining consumer's input operands (drop past index `consumerIdx`). 200 for (BlockArgument bbArg : 201 consumerBlock.getArguments() 202 .take_front(consumer.getNumInputs()) 203 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 204 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 205 // 6. All of consumer's output operands. 206 for (BlockArgument bbArg : 207 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 208 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 209 // 7. All of producer's output operands except the one fused. 210 // TODO: allow fusion of multi-result producers. 211 assert(producer->getNumResults() == 1 && "expected single result producer"); 212 213 // 8. Clone all producer operations except for the yield and index operations 214 // to the fused operation. 215 for (auto &op : producerBlock.without_terminator()) { 216 if (!isa<IndexOp>(op)) 217 rewriter.clone(op, mapper); 218 } 219 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just 220 // forward the yield operand. 221 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 222 // TODO: allow fusion of multi-result producers. 223 assert(producer->getNumResults() == 1 && "expected single result producer"); 224 unsigned producerResultNumber = 0; 225 Value replacement = 226 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 227 // Sanity checks, if replacement is not already in the mapper then it must be 228 // produced outside. 229 if (replacement == yieldOp.getOperand(producerResultNumber)) { 230 if (auto bb = replacement.dyn_cast<BlockArgument>()) 231 assert(bb.getOwner() != &producerBlock && 232 "yielded block argument must have been mapped"); 233 else 234 assert(!producer->isAncestor(replacement.getDefiningOp()) && 235 "yielded value must have been mapped"); 236 } 237 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 238 replacement); 239 // 10. Clone operations from the consumer to the fused op. 240 for (auto &op : consumerBlock.getOperations()) 241 rewriter.clone(op, mapper); 242 243 // Sanity checks. 244 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 245 "Ill-formed GenericOp region"); 246 } 247 248 static Optional<SmallVector<Value>> 249 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 250 const ControlElementwiseOpsFusionFn &controlFn, 251 PatternRewriter &rewriter) { 252 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 253 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 254 !controlFn(producer->getResult(0), *consumerOpOperand)) 255 return llvm::None; 256 257 // TODO: allow fusing the producer of an output operand. 258 assert(consumer.isInputTensor(consumerOpOperand) && 259 "expected producer of input operand"); 260 261 // Compute the fused operands list and indexing maps. 262 SmallVector<Value> fusedOperands; 263 SmallVector<AffineMap> fusedIndexMaps; 264 fusedOperands.reserve(producer->getNumOperands() + 265 consumer->getNumOperands()); 266 fusedIndexMaps.reserve(producer->getNumOperands() + 267 consumer->getNumOperands()); 268 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 269 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 270 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 271 SmallVector<OpOperand *>::iterator it = 272 llvm::find(consumerInputs, consumerOpOperand); 273 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 274 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 275 fusedOperands.push_back(opOperand->get()); 276 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 277 } 278 // 4. Splice in producer's input operands/maps. 279 assert(producer->getNumResults() == 1 && "expected single result producer"); 280 AffineMap producerResultIndexMap = 281 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 282 for (OpOperand *opOperand : producer.getInputOperands()) { 283 fusedOperands.push_back(opOperand->get()); 284 // Compute indexing maps for the producer args in the fused operation. 285 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 286 opOperand, producerResultIndexMap, 287 consumer.getTiedIndexingMap(consumerOpOperand)); 288 fusedIndexMaps.push_back(map); 289 } 290 // 4.b. Producer output operand/map that is fused needs to be passed if it is 291 // an "initTensor" (i.e. its value is actually read). 292 assert(producer->getNumResults() == 1 && "expected single result producer"); 293 if (producer.isInitTensor(producer.getOutputOperand(0))) { 294 fusedOperands.push_back(producer.getOutputOperand(0)->get()); 295 // Compute indexing maps for the producer args in the fused operation. 296 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 297 producer.getOutputOperand(0), producerResultIndexMap, 298 consumer.getTiedIndexingMap(consumerOpOperand)); 299 fusedIndexMaps.push_back(map); 300 } 301 // 5. Remaining consumer's input operands/maps (drop past index 302 // `consumerIdx`). 303 for (OpOperand *opOperand : 304 llvm::make_range(std::next(it), consumerInputs.end())) { 305 fusedOperands.push_back(opOperand->get()); 306 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 307 } 308 // 6. All of consumer's output operands (skip operands: added by the builder). 309 for (OpOperand *opOperand : consumer.getOutputOperands()) 310 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 311 // 7. All of producer's output operands/maps except the one fused. 312 // TODO: allow fusion of multi-result producers. 313 assert(producer->getNumResults() == 1 && "expected single result producer"); 314 315 // Generate the fused op. 316 SmallVector<Value> consumerOutputs = consumer.getOutputOperands(); 317 auto fusedOp = rewriter.create<GenericOp>( 318 consumer.getLoc(), consumer->getResultTypes(), 319 /*inputs=*/fusedOperands, 320 // TODO: handle outputs. 321 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 322 consumer.iterator_types(), 323 /*doc=*/nullptr, 324 /*library_call=*/nullptr); 325 if (!fusedOp.getShapesToLoopsMap()) { 326 // Fused op has invalid indexing maps. Typically this means something is off 327 // in the input, but going ahead here would result in verification errors. 328 // So cleanup and abort. 329 rewriter.eraseOp(fusedOp); 330 return llvm::None; 331 } 332 333 // Construct an AffineMap from consumer loops to producer loops. 334 // consumer loop -> tensor index 335 AffineMap consumerResultIndexMap = 336 consumer.getTiedIndexingMap(consumerOpOperand); 337 // tensor index -> producer loop 338 AffineMap invProducerResultIndexMap = 339 inversePermutation(producerResultIndexMap); 340 assert(invProducerResultIndexMap && 341 "expected producer result indexig map to be invertible"); 342 // consumer loop -> producer loop 343 AffineMap consumerToProducerLoopsMap = 344 invProducerResultIndexMap.compose(consumerResultIndexMap); 345 346 generateFusedElementwiseOpRegion(rewriter, fusedOp, 347 consumerToProducerLoopsMap, 348 consumerOpOperand, consumer.getNumLoops()); 349 return SmallVector<Value>(fusedOp->getResults()); 350 } 351 352 static Optional<SmallVector<Value>> 353 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, 354 GenericOp producer, 355 const ControlElementwiseOpsFusionFn &controlFn) { 356 if (producer->getNumResults() != 1) 357 return llvm::None; 358 359 return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn, 360 rewriter); 361 } 362 363 namespace { 364 /// Patterns to fuse a generic op, with the producer of its operands. 365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> { 366 public: 367 FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun, 368 PatternBenefit benefit = 1) 369 : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {} 370 371 LogicalResult matchAndRewrite(GenericOp genericOp, 372 PatternRewriter &rewriter) const override { 373 // Find the first operand that is defined by another generic op on tensors. 374 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 375 auto producer = 376 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp()); 377 if (!producer || !producer.hasTensorSemantics()) 378 continue; 379 Optional<SmallVector<Value>> fusedOpResults = 380 fuseElementwiseOps(rewriter, opOperand, producer, controlFn); 381 if (fusedOpResults) { 382 rewriter.replaceOp(genericOp, *fusedOpResults); 383 return success(); 384 } 385 } 386 return failure(); 387 } 388 389 private: 390 ControlElementwiseOpsFusionFn controlFn; 391 }; 392 } // namespace 393 394 //===---------------------------------------------------------------------===// 395 // Methods and patterns that fuse reshape ops with elementwise operations by 396 // linearization of indexing maps. 397 //===---------------------------------------------------------------------===// 398 399 // TODO(ravishankarm): The indexing maps 400 // these produce in the general case are detrimental to transformations. 401 // These patterns are on deprecation path in favor of using fusion by 402 // collapsing, which covers the only legitimate use case of this pattern of 403 // folding unit-extent dims. 404 405 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` 406 /// provided, given the shape of the source tensor that corresponds to the 407 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions 408 /// are "row-major" ordered logically. 409 /// 410 /// For example: 411 /// 412 /// %0 = op ... : tensor<?x?x4x5xf32> 413 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` 414 /// 415 /// and reshape: 416 /// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] : 417 /// tensor<?x?x4x5xf32> into tensor<?x?xf32> 418 /// 419 /// would be rewritten into: 420 /// %0 = op ... : tensor<?x?x4x5xf32> 421 /// with output index_map 422 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` 423 template <typename TensorReshapeOp> 424 static AffineMap linearizeCollapsedDims(AffineMap sourceMap, 425 TensorReshapeOp reshapeOp) { 426 constexpr bool isExpanding = 427 std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value; 428 ArrayRef<int64_t> sourceShape = 429 (isExpanding ? reshapeOp.getResultType().getShape() 430 : reshapeOp.getSrcType().getShape()); 431 SmallVector<AffineExpr> resultExprs; 432 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults(); 433 MLIRContext *context = sourceMap.getContext(); 434 435 // Compute the result exprs based on the reassociation maps. 436 for (auto &indices : reshapeOp.getReassociationIndices()) { 437 // Assume that they are in-order and contiguous (already checked in 438 // verifier). 439 assert(!indices.empty()); 440 SmallVector<int64_t> sizes; 441 SmallVector<AffineExpr> dimExprs; 442 for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()), 443 sourceExprs.slice(indices[0], indices.size()))) { 444 if (std::get<0>(en) == 1) 445 continue; 446 sizes.push_back(std::get<0>(en)); 447 dimExprs.push_back(std::get<1>(en)); 448 } 449 AffineExpr linearizedExpr = 450 makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); 451 resultExprs.push_back(linearizedExpr); 452 } 453 // The new affine map cannot drop unused dimension but some new symbols may 454 // have been added. Create a map with at least as many dimensions/symbols as 455 // the original affine map. 456 int64_t maxDim = -1; 457 int64_t maxSym = -1; 458 getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym); 459 unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims()); 460 unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols()); 461 return AffineMap::get(numDims, numSyms, resultExprs, context); 462 } 463 464 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a 465 // producer). Fusing when operand has higher rank will require use of mods and 466 // divs in the indexing maps of the fused op which would make it non-invertible. 467 static bool isTensorReshapeOpFoldableByLinearization( 468 tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { 469 if (!asProducer) 470 return false; 471 return useIndexMap.isPermutation(); 472 } 473 474 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a 475 // consumer). 476 static bool 477 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp, 478 AffineMap useIndexMap, 479 bool asProducer) { 480 if (asProducer) 481 return false; 482 return useIndexMap.isPermutation(); 483 } 484 485 /// Check if the reshape operation is only expansion into/collapsing of 486 /// unit-dimension. 487 template <typename TensorReshapeOp> 488 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) { 489 constexpr bool isExpanding = 490 std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value; 491 ArrayRef<int64_t> expandedShape = 492 (isExpanding ? reshapeOp.getResultType().getShape() 493 : reshapeOp.getSrcType().getShape()); 494 for (auto &indices : reshapeOp.getReassociationIndices()) { 495 unsigned numUnitDims = 0; 496 for (int64_t position : indices) 497 if (expandedShape[position] == 1) 498 numUnitDims++; 499 if (numUnitDims != indices.size() - 1) 500 return false; 501 } 502 return true; 503 } 504 505 namespace { 506 /// Pattern to fold tensor_expand_shape op with its consumer by using the source 507 /// of the reshape op as the operand in the consumer (instead of the result of 508 /// the tensor_collapse_shape). The corresponding index map in the consumer 509 /// needs to be modified to linearize the folded dimension. 510 /// 511 /// For example, 512 /// 513 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 514 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] 515 /// tensor<?x?x?xf32> into tensor<?x?x4x?xf32> 516 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... } 517 /// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ... 518 /// -> tensor<?x?x4x?xf32> 519 /// 520 /// can be folded into 521 /// 522 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> 523 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 524 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... } 525 /// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ... 526 /// -> tensor<?x?x4x?xf32> 527 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp> 528 struct FoldProducerReshapeOpByLinearization 529 : public OpRewritePattern<GenericOp> { 530 using OpRewritePattern<GenericOp>::OpRewritePattern; 531 532 LogicalResult matchAndRewrite(GenericOp genericOp, 533 PatternRewriter &rewriter) const override { 534 if (!genericOp.hasTensorSemantics()) 535 return failure(); 536 SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands(); 537 for (const auto &en : llvm::enumerate(inputOperands)) { 538 auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>(); 539 if (!reshapeOp) 540 continue; 541 542 if (!isTensorReshapeOpFoldableByLinearization( 543 reshapeOp, genericOp.getTiedIndexingMap(en.value()), 544 /*asProducer =*/true) || 545 (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) 546 continue; 547 548 // Compute the fused operands list, 549 SmallVector<Value> fusedOperands = genericOp.getInputOperands(); 550 fusedOperands[en.index()] = reshapeOp.src(); 551 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 552 llvm::append_range(fusedOperands, outputOperands); 553 554 // Compute indexing_maps for the fused operation. The indexing_maps for 555 // the operands of the consumers that arent fused are the same. 556 SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps(); 557 558 // Compute the indexing map to use for the result of the producer. 559 AffineMap modifiedMap = 560 linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp); 561 // The modified map cannot have symbols. 562 if (modifiedMap.getNumSymbols()) 563 return failure(); 564 for (AffineExpr expr : modifiedMap.getResults()) { 565 if (!expr.isPureAffine()) 566 return failure(); 567 } 568 fusedIndexMaps[en.index()] = modifiedMap; 569 570 // Further check that the resulting index maps can be fused and 571 // inverted. Without this the resultant op is not legal. 572 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 573 return rewriter.notifyMatchFailure( 574 genericOp, "fused op loop bound computation failed"); 575 } 576 577 rewriter.startRootUpdate(genericOp); 578 genericOp->setOperands(fusedOperands); 579 genericOp.indexing_mapsAttr( 580 rewriter.getAffineMapArrayAttr(fusedIndexMaps)); 581 rewriter.finalizeRootUpdate(genericOp); 582 return success(); 583 } 584 return failure(); 585 } 586 }; 587 588 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its 589 /// producer. The corresponding index map in the consumer needs to be modified 590 /// to linearize the folded dimension. 591 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp> 592 struct FoldConsumerReshapeOpByLinearization 593 : public OpRewritePattern<TensorReshapeOp> { 594 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 595 596 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 597 PatternRewriter &rewriter) const override { 598 GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>(); 599 if (!producer || !producer.hasTensorSemantics() || 600 producer.getNumOutputs() != 1 || 601 !isTensorReshapeOpFoldableByLinearization( 602 reshapeOp, 603 producer.getTiedIndexingMap(producer.getOutputOperand(0)), 604 /*asProducer =*/false) || 605 (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) 606 return failure(); 607 // The indexing_maps for the operands of the fused operation are same as 608 // those for the operands of the producer. 609 SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps(); 610 611 // Compute the indexing map to use for the operand of the producer. 612 AffineMap modifiedMap = linearizeCollapsedDims( 613 producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp); 614 for (AffineExpr expr : modifiedMap.getResults()) { 615 if (!expr.isPureAffine()) { 616 return rewriter.notifyMatchFailure( 617 producer, "fused op indexing map is not affine"); 618 } 619 } 620 fusedIndexMaps.back() = modifiedMap; 621 622 // Further check that the resulting index maps can be fused and 623 // inverted. Without this the resultant op is not legal. 624 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 625 return rewriter.notifyMatchFailure( 626 producer, "fused op loop bound computation failed"); 627 } 628 629 Location loc = producer.getLoc(); 630 SmallVector<Value> inputOperands = producer.getInputOperands(); 631 Value output = rewriter.create<TensorReshapeOp>( 632 loc, producer.getOutputOperand(0)->get(), 633 reshapeOp.getReassociationExprs()); 634 auto fusedOp = rewriter.create<GenericOp>( 635 loc, reshapeOp.getResultType(), 636 /*inputs=*/inputOperands, 637 // TODO: handle outputs. 638 /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 639 producer.iterator_types(), 640 /*doc=*/nullptr, 641 /*library_call=*/nullptr); 642 auto &fusedRegion = fusedOp->getRegion(0); 643 rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion, 644 fusedRegion.begin()); 645 rewriter.replaceOp(reshapeOp, fusedOp->getResults()); 646 return success(); 647 } 648 }; 649 } // namespace 650 651 //===---------------------------------------------------------------------===// 652 // Methods and patterns that fuse reshape ops with elementwise operations by 653 // expanding the dimensionality of the elementwise operations. 654 //===---------------------------------------------------------------------===// 655 656 /// Conditions for folding a generic operation with a reshape op by expanding 657 /// the iteration space dimensionality for tensor operations. These are 658 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 659 /// following fusion pattern. 660 /// 661 /// Consider 662 /// 663 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>) 664 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 665 /// affine_map<(d0, d1, d2) -> (d1, d2)>, 666 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] 667 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] 668 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 669 /// 670 /// The reshape can be folded into the `genericOp` if its loop dimensionality 671 /// is increased to match the result (operand) of the tensor_expand_shape. 672 /// The indexing_map of the fused tensor in the `genericOp` and the 673 /// reassociation map helps compute the indexing maps of the modified op. 674 /// For the above example, based on the reassociation map it 675 /// can be concluded that 676 /// 677 /// - The loop used to access the first dimension of the fused tensor is split 678 /// into two. 679 /// - The loop used to access the second dimension of the fused tensor is kept 680 /// as is. 681 /// - The loop used to access the third dimension of the fused tensor is split 682 /// into three. 683 /// 684 /// i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified 685 /// op, then 686 /// 687 /// d0 -> e0, e1 688 /// d1 -> e2, e3, e4 689 /// d2 -> e5 690 /// 691 /// substituting this, the generic op can be rewritten as 692 /// 693 /// %d = linalg.generic ins(%0, %1 : ) 694 /// indexing_maps = 695 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>, 696 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>, 697 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>] 698 /// 699 /// Since operands to the linalg generic are now 5D, reshapes can be introduced 700 /// to make it consistent 701 /// 702 /// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]] 703 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 704 /// %1 = tensor.expand_shape %b [[0, 1, 2], [3]] 705 /// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32> 706 /// 707 /// The added reshapes are again expanding patterns, so they will get fused 708 /// with its producers if possible. 709 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, 710 OpOperand *fusableOpOperand) { 711 // Is fusable only if: 712 // - All the indexing maps for operands and results are projected 713 // permutations. 714 // - The fused tensor is not a scalar. 715 // - All the loops are parallel loops. 716 return genericOp.hasTensorSemantics() && 717 llvm::all_of(genericOp.indexing_maps().getValue(), 718 [](Attribute attr) { 719 return attr.cast<AffineMapAttr>() 720 .getValue() 721 .isProjectedPermutation(); 722 }) && 723 genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && 724 llvm::all_of(genericOp.iterator_types(), [](Attribute attr) { 725 return attr.cast<StringAttr>().getValue() == 726 getParallelIteratorTypeName(); 727 }); 728 } 729 730 namespace { 731 /// Information needed to expand a generic operation to fold the reshape with 732 /// it. 733 class ExpansionInfo { 734 public: 735 // Computes the mapping from original dimensions of the op to the dimensions 736 // of the expanded op given the `indexingMap` of the fused operand/result of 737 // the generic op, the `reassocationMaps` of the reshape op and the shape of 738 // the expanded op. 739 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand, 740 ArrayRef<AffineMap> reassociationMaps, 741 ArrayRef<int64_t> expandedShape, 742 ArrayRef<int64_t> collapsedShape, 743 PatternRewriter &rewriter); 744 unsigned getOrigOpNumDims() const { return reassociation.size(); } 745 unsigned getExpandedOpNumDims() const { return expandedOpNumDims; } 746 ReassociationIndicesRef getExpandedDims(unsigned i) const { 747 return reassociation[i]; 748 } 749 ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const { 750 return expandedShapeMap[i]; 751 } 752 ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; } 753 754 private: 755 /// Reassociation from the dimensions in the original operation to the 756 /// dimension of the expanded operation. 757 SmallVector<ReassociationIndices> reassociation; 758 /// Mapping from extent of loops in the original operation, to the extent of 759 /// loops in the expanded operation. 760 SmallVector<SmallVector<int64_t>> expandedShapeMap; 761 /// Extent of the loop in the original operation. 762 SmallVector<int64_t> originalLoopExtent; 763 unsigned expandedOpNumDims; 764 }; 765 } // namespace 766 767 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp, 768 OpOperand *fusableOpOperand, 769 ArrayRef<AffineMap> reassociationMaps, 770 ArrayRef<int64_t> expandedShape, 771 ArrayRef<int64_t> collapsedShape, 772 PatternRewriter &rewriter) { 773 if (reassociationMaps.empty()) 774 return failure(); 775 AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand); 776 777 Optional<SmallVector<int64_t, 4>> originalLoopRange = 778 linalgOp.getStaticLoopRanges(); 779 if (!originalLoopRange) 780 return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range"); 781 originalLoopExtent.assign(originalLoopRange->begin(), 782 originalLoopRange->end()); 783 784 reassociation.clear(); 785 expandedShapeMap.clear(); 786 // Compute the number of dimension in the expanded op that correspond to each 787 // dimension of the original op. 788 SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1); 789 expandedShapeMap.resize(fusedIndexMap.getNumDims()); 790 for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { 791 unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition(); 792 AffineMap foldedDims = reassociationMaps[resultExpr.index()]; 793 numExpandedDims[pos] = foldedDims.getNumResults(); 794 ArrayRef<int64_t> shape = 795 expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]); 796 expandedShapeMap[pos].assign(shape.begin(), shape.end()); 797 } 798 // The remaining dimensions remain the same. 799 for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims())) 800 if (expandedShapeMap[i].empty()) 801 expandedShapeMap[i] = {originalLoopExtent[i]}; 802 803 // Compute reassociation map from the original op to the expanded op. 804 unsigned sum = 0; 805 reassociation.reserve(fusedIndexMap.getNumDims()); 806 for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) { 807 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value()); 808 reassociation.emplace_back(seq.begin(), seq.end()); 809 sum += numFoldedDim.value(); 810 } 811 expandedOpNumDims = sum; 812 return success(); 813 } 814 815 /// Epanding the body of a linalg operation requires adaptations of the accessed 816 /// loop indices. Specifically, access of indices in the original operation need 817 /// to be replaced with linearizations of indices in the expanded op. That 818 /// requires the shape of the expanded dimensions to be static (at least all but 819 /// the most significant). For now check that these are all statically sized. 820 /// Note that this could be extended to handle dynamic case, but the 821 /// implementation below uses `affine.apply` which seems to have issues when the 822 /// shapes are not static. 823 static LogicalResult isGenericOpExpandable(GenericOp genericOp, 824 const ExpansionInfo &expansionInfo, 825 PatternRewriter &rewriter) { 826 if (!genericOp.hasIndexSemantics()) 827 return success(); 828 for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) { 829 ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i); 830 if (expandedShape.size() == 1) 831 continue; 832 for (int64_t shape : expandedShape.drop_front()) { 833 if (ShapedType::isDynamic(shape)) { 834 return rewriter.notifyMatchFailure( 835 genericOp, "cannot expand due to index semantics and dynamic dims"); 836 } 837 } 838 } 839 return success(); 840 } 841 842 /// Return the indexing map to use in the expanded op for a given the 843 /// `indexingMap` of the original operation. 844 static AffineMap 845 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 846 const ExpansionInfo &expansionInfo) { 847 SmallVector<AffineExpr> newExprs; 848 for (AffineExpr expr : indexingMap.getResults()) { 849 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 850 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 851 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 852 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 853 })); 854 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 855 } 856 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 857 indexingMap.getNumSymbols(), newExprs, 858 builder.getContext()); 859 } 860 861 /// Return the type of the operand/result to use in the expanded op given the 862 /// type in the original op. 863 static RankedTensorType getExpandedType(RankedTensorType originalType, 864 AffineMap indexingMap, 865 const ExpansionInfo &expansionInfo) { 866 SmallVector<int64_t> expandedShape; 867 for (AffineExpr expr : indexingMap.getResults()) { 868 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 869 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 870 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 871 } 872 return RankedTensorType::get(expandedShape, originalType.getElementType()); 873 } 874 875 /// Returns the reassociation maps to use in the `tensor.expand_shape` 876 /// operation to convert the operands of the original operation to operands of 877 /// the expanded operation. The same method is used to compute the 878 /// `tensor.collapse_shape` used to collapse the result of the expanded 879 /// op to get the value that can replace all uses of the results of the original 880 /// op. 881 static SmallVector<ReassociationIndices> 882 getReassociationForExpansion(AffineMap indexingMap, 883 const ExpansionInfo &expansionInfo) { 884 SmallVector<ReassociationIndices> reassociation; 885 unsigned numReshapeDims = 0; 886 for (AffineExpr expr : indexingMap.getResults()) { 887 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 888 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 889 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 890 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 891 reassociation.emplace_back(std::move(indices)); 892 numReshapeDims += numExpandedDims; 893 } 894 return reassociation; 895 } 896 897 /// Update the body of an expanded linalg operation having index semantics. The 898 /// indices of the original operation need to be recovered by linearizing the 899 /// indices of the correspoding dimensions of the expanded operation. For now it 900 /// is assumed that the shapes of the expanded operation needed for 901 /// linearization are static. 902 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, 903 Location loc, Region &fusedRegion, 904 const ExpansionInfo &expansionInfo) { 905 // Replace the original indices by the linearization of the expanded indices. 906 for (IndexOp indexOp : 907 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) { 908 ArrayRef<int64_t> expandedDims = 909 expansionInfo.getExpandedDims(indexOp.dim()); 910 assert(!expandedDims.empty() && "expected valid expansion info"); 911 912 // Skip index operations that are not affected by the expansion. 913 if (expandedDims.size() == 1 && 914 expandedDims.front() == (int64_t)indexOp.dim()) 915 continue; 916 917 // Linearize the expanded indices of the original index dimension. 918 OpBuilder::InsertionGuard guard(rewriter); 919 rewriter.setInsertionPointAfter(indexOp); 920 ArrayRef<int64_t> expandedDimsShape = 921 expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front(); 922 SmallVector<Value> expandedIndices; 923 expandedIndices.reserve(expandedDims.size() - 1); 924 llvm::transform( 925 expandedDims.drop_front(), std::back_inserter(expandedIndices), 926 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); }); 927 Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front()); 928 for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) { 929 assert(!ShapedType::isDynamic(std::get<0>(it))); 930 AffineExpr idx, acc; 931 bindDims(rewriter.getContext(), idx, acc); 932 newIndex = rewriter.create<AffineApplyOp>( 933 indexOp.getLoc(), idx + acc * std::get<0>(it), 934 ValueRange{std::get<1>(it), newIndex}); 935 } 936 rewriter.replaceOp(indexOp, newIndex); 937 } 938 } 939 940 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op 941 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes 942 /// that those conditions have been satisfied. 943 static Optional<SmallVector<Value>> 944 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, 945 OpOperand *fusableOpOperand, 946 PatternRewriter &rewriter) { 947 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && 948 "preconditions for fuse operation failed"); 949 // Check if reshape is expanding or collapsing. 950 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp); 951 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp); 952 bool isExpanding = (expandingReshapeOp != nullptr); 953 RankedTensorType expandedType = isExpanding 954 ? expandingReshapeOp.getResultType() 955 : collapsingReshapeOp.getSrcType(); 956 RankedTensorType collapsedType = isExpanding 957 ? expandingReshapeOp.getSrcType() 958 : collapsingReshapeOp.getResultType(); 959 960 ExpansionInfo expansionInfo; 961 if (failed(expansionInfo.compute( 962 genericOp, fusableOpOperand, 963 isExpanding ? expandingReshapeOp.getReassociationMaps() 964 : collapsingReshapeOp.getReassociationMaps(), 965 expandedType.getShape(), collapsedType.getShape(), rewriter))) 966 return llvm::None; 967 968 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) 969 return llvm::None; 970 971 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>( 972 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) { 973 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); 974 })); 975 976 SmallVector<Value> expandedOpOperands; 977 expandedOpOperands.reserve(genericOp.getNumInputs()); 978 for (OpOperand *opOperand : genericOp.getInputOperands()) { 979 if (opOperand == fusableOpOperand) { 980 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() 981 : collapsingReshapeOp.src()); 982 continue; 983 } 984 if (genericOp.isInputTensor(opOperand)) { 985 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 986 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 987 RankedTensorType expandedOperandType = 988 getExpandedType(opOperandType, indexingMap, expansionInfo); 989 if (expandedOperandType != opOperand->get().getType()) { 990 // Reshape the operand to get the right type. 991 SmallVector<ReassociationIndices> reassociation = 992 getReassociationForExpansion(indexingMap, expansionInfo); 993 if (failed(reshapeLikeShapesAreCompatible( 994 [&](const Twine &msg) { 995 return rewriter.notifyMatchFailure(genericOp, msg); 996 }, 997 opOperandType.getShape(), expandedOperandType.getShape(), 998 reassociation, 999 /*isExpandingReshape=*/true))) 1000 return llvm::None; 1001 expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>( 1002 genericOp.getLoc(), expandedOperandType, opOperand->get(), 1003 reassociation)); 1004 continue; 1005 } 1006 } 1007 expandedOpOperands.push_back(opOperand->get()); 1008 } 1009 1010 Location loc = genericOp.getLoc(); 1011 SmallVector<Value> outputs; 1012 for (OpOperand *opOperand : genericOp.getOutputOperands()) { 1013 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 1014 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 1015 RankedTensorType expandedOutputType = 1016 getExpandedType(opOperandType, indexingMap, expansionInfo); 1017 if (expandedOutputType != opOperand->get().getType()) { 1018 SmallVector<ReassociationIndices> reassociation = 1019 getReassociationForExpansion(indexingMap, expansionInfo); 1020 if (failed(reshapeLikeShapesAreCompatible( 1021 [&](const Twine &msg) { 1022 return rewriter.notifyMatchFailure(genericOp, msg); 1023 }, 1024 opOperandType.getShape(), expandedOutputType.getShape(), 1025 reassociation, 1026 /*isExpandingReshape=*/true))) 1027 return llvm::None; 1028 outputs.push_back(rewriter.create<tensor::ExpandShapeOp>( 1029 genericOp.getLoc(), expandedOutputType, opOperand->get(), 1030 reassociation)); 1031 } 1032 } 1033 1034 // The iterator types of the expanded op are all parallel. 1035 SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(), 1036 getParallelIteratorTypeName()); 1037 1038 TypeRange resultTypes = ValueRange(outputs).getTypes(); 1039 auto fusedOp = 1040 rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes, 1041 /*inputs=*/expandedOpOperands, outputs, 1042 expandedOpIndexingMaps, iteratorTypes); 1043 Region &fusedRegion = fusedOp->getRegion(0); 1044 Region &originalRegion = genericOp->getRegion(0); 1045 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin()); 1046 1047 // Update the index accesses after the expansion. 1048 updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo); 1049 1050 // Reshape the result values to their original shape if this is a collapsing 1051 // reshape folded into its consumer. 1052 SmallVector<Value> resultVals; 1053 for (OpResult opResult : genericOp->getOpResults()) { 1054 int64_t resultNumber = opResult.getResultNumber(); 1055 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) { 1056 SmallVector<ReassociationIndices> reassociation = 1057 getReassociationForExpansion( 1058 genericOp.getTiedIndexingMap( 1059 genericOp.getOutputOperand(resultNumber)), 1060 expansionInfo); 1061 resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>( 1062 genericOp.getLoc(), opResult.getType(), 1063 fusedOp->getResult(resultNumber), reassociation)); 1064 } else { 1065 resultVals.push_back(fusedOp->getResult(resultNumber)); 1066 } 1067 } 1068 // Assuming a single result. 1069 return resultVals; 1070 } 1071 1072 namespace { 1073 1074 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op, 1075 /// when the reshape op is collapsing dimensions. The dimensionality of the loop 1076 /// in the consumer is expanded. 1077 class FoldWithProducerReshapeOpByExpansion 1078 : public OpRewritePattern<GenericOp> { 1079 public: 1080 FoldWithProducerReshapeOpByExpansion( 1081 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1082 PatternBenefit benefit = 1) 1083 : OpRewritePattern<GenericOp>(context, benefit), 1084 controlFoldingReshapes(std::move(foldReshapes)) {} 1085 1086 LogicalResult matchAndRewrite(GenericOp genericOp, 1087 PatternRewriter &rewriter) const override { 1088 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 1089 tensor::CollapseShapeOp reshapeOp = 1090 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>(); 1091 if (!reshapeOp) 1092 continue; 1093 // Fold only if 1094 // - The tensor reshape op is folding. 1095 // - All constraints of fusing with reshape by expansion are met. 1096 if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) || 1097 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand))) 1098 continue; 1099 1100 Optional<SmallVector<Value>> replacementValues = 1101 fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); 1102 if (!replacementValues) 1103 return failure(); 1104 rewriter.replaceOp(genericOp, replacementValues.getValue()); 1105 return success(); 1106 } 1107 return failure(); 1108 } 1109 1110 private: 1111 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1112 }; 1113 1114 /// Pattern to fold a tensor_expand_shape op with its producer generic op 1115 /// by expanding the dimensionality of the loop in the producer op. 1116 struct FoldReshapeWithGenericOpByExpansion 1117 : public OpRewritePattern<tensor::ExpandShapeOp> { 1118 1119 FoldReshapeWithGenericOpByExpansion( 1120 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1121 PatternBenefit benefit = 1) 1122 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), 1123 controlFoldingReshapes(std::move(foldReshapes)) {} 1124 1125 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, 1126 PatternRewriter &rewriter) const override { 1127 // Fold only if all constraints of fusing with reshape by expansion are met. 1128 GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>(); 1129 if (!producer || producer.getNumOutputs() != 1 || 1130 !isFusableWithReshapeByDimExpansion(producer, 1131 producer.getOutputOperand(0)) || 1132 !controlFoldingReshapes(producer->getResult(0), 1133 reshapeOp->getOpOperand(0))) 1134 return failure(); 1135 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 1136 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 1137 if (!replacementValues) 1138 return failure(); 1139 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 1140 return success(); 1141 } 1142 1143 private: 1144 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1145 }; 1146 } // namespace 1147 1148 //===---------------------------------------------------------------------===// 1149 // Methods and patterns to fuse reshape with linalg.generic operations by 1150 // contraction of dimensions. 1151 //===---------------------------------------------------------------------===// 1152 1153 /// For an `indexingMap` that is a projected permutation, if the range is to be 1154 /// collapsed using the given `reassociation`, get the reassociation in the 1155 /// domain that would keep the map a projected permutation. 1156 static SmallVector<ReassociationIndices> 1157 getDomainReassociation(AffineMap indexingMap, 1158 ArrayRef<ReassociationIndices> rangeReassociation) { 1159 assert(indexingMap.isProjectedPermutation() && 1160 "expected projected permutation map"); 1161 unsigned counter = 0; 1162 SmallVector<ReassociationIndices> domainReassociation; 1163 llvm::SmallDenseSet<unsigned, 4> processedDomainDims; 1164 // Iterate over the reassociation indices. 1165 for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) { 1166 ReassociationIndices foldedDomainDims; 1167 for (auto rangeDim : foldedRangeDims) { 1168 (void)rangeDim; 1169 AffineDimExpr dimExpr = 1170 indexingMap.getResult(counter++).cast<AffineDimExpr>(); 1171 foldedDomainDims.push_back(dimExpr.getPosition()); 1172 processedDomainDims.insert(dimExpr.getPosition()); 1173 } 1174 domainReassociation.emplace_back(std::move(foldedDomainDims)); 1175 } 1176 // Fill in the missing domain dims. 1177 for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) { 1178 if (processedDomainDims.count(dim)) 1179 continue; 1180 ReassociationIndices vec = {dim}; 1181 domainReassociation.emplace_back(std::move(vec)); 1182 } 1183 1184 // Sort the reassociation using the first dimension of the folded range to 1185 // not create unnecessary transposes. 1186 llvm::sort(domainReassociation, 1187 [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) { 1188 return lhs[0] < rhs[0]; 1189 }); 1190 return domainReassociation; 1191 } 1192 1193 /// For a given `dimSequence`, check if the sequence is conserved in the 1194 /// `indexingMap`. `indexingMap` is expected to be a projected permutation. 1195 /// Non-existence of the sequence returns true as well. 1196 static bool isDimSequencePreserved(AffineMap indexingMap, 1197 ReassociationIndicesRef dimSequence) { 1198 assert(!dimSequence.empty() && 1199 "expected non-empty list for dimension sequence"); 1200 assert(indexingMap.isProjectedPermutation() && 1201 "expected indexing map to be projected permutation"); 1202 1203 llvm::SmallDenseSet<unsigned, 4> sequenceElements; 1204 sequenceElements.insert(dimSequence.begin(), dimSequence.end()); 1205 1206 unsigned dimSequenceStart = dimSequence[0]; 1207 for (auto expr : enumerate(indexingMap.getResults())) { 1208 unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition(); 1209 // 1. Check if this start of the sequence. 1210 if (dimInMapStart == dimSequenceStart) { 1211 if (expr.index() + dimSequence.size() > indexingMap.getNumResults()) 1212 return false; 1213 // 1a. Check if sequence is preserved. 1214 for (auto dimInSequence : enumerate(dimSequence)) { 1215 unsigned dimInMap = 1216 indexingMap.getResult(expr.index() + dimInSequence.index()) 1217 .cast<AffineDimExpr>() 1218 .getPosition(); 1219 if (dimInMap != dimInSequence.value()) 1220 return false; 1221 } 1222 // Found the sequence. Projected permutation 1223 // enforces that all AffineDimExprs in the result are unique, so no 1224 // further checks are needed. 1225 return true; 1226 } 1227 // 2. If position in the expr (which is of type AffineDimExpr) is part 1228 // of sequence, return false here. This implies the entire sequence does not 1229 // exist in the indexing map. 1230 if (sequenceElements.count(dimInMapStart)) 1231 return false; 1232 } 1233 // 3. No element of sequence found. Return true. 1234 return true; 1235 } 1236 1237 // Check if a generic op can be fused along an operand by collapsing dimensions. 1238 static bool isFusableWithReshapeByDimCollapse( 1239 GenericOp genericOp, OpOperand *fusableOperand, 1240 ArrayRef<ReassociationIndices> reassociation) { 1241 // Some basic checks for this fusion to be valid. 1242 if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1) 1243 return false; 1244 1245 if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) { 1246 return map.isProjectedPermutation(); 1247 })) { 1248 return false; 1249 } 1250 1251 // Get the reassociation for the iteration space. 1252 SmallVector<ReassociationIndices> iterationReassociation = 1253 getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand), 1254 reassociation); 1255 if (iterationReassociation.empty()) { 1256 // If the domain reassociation indices is empty, then this is a scalar op. 1257 // Nothing to do. 1258 return false; 1259 } 1260 1261 auto iteratorTypes = genericOp.iterator_types().getValue(); 1262 ArrayRef<Attribute> iteratorTypesRef(iteratorTypes); 1263 for (ReassociationIndicesRef foldedIterDims : iterationReassociation) { 1264 // Check that for all indexing maps, the folded dimensions sequence is 1265 // preserved. 1266 if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) { 1267 return isDimSequencePreserved(indexingMap, foldedIterDims); 1268 })) 1269 return false; 1270 unsigned startDim = foldedIterDims[0]; 1271 ArrayRef<Attribute> foldedIteratorTypes = 1272 iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size()); 1273 // Check that all folded iterator types are either all parallel, or all 1274 // reduction. 1275 if (!llvm::all_of( 1276 foldedIteratorTypes, 1277 [](Attribute attr) { return isParallelIterator(attr); }) && 1278 !llvm::all_of(foldedIteratorTypes, 1279 [](Attribute attr) { return isReductionIterator(attr); })) 1280 return false; 1281 } 1282 return true; 1283 } 1284 1285 /// Helper class to carry state while collapsing the `linalg.generic` op. 1286 namespace { 1287 class CollapsingInfo { 1288 public: 1289 CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) { 1290 iterationReassociation = std::move(reassociation); 1291 for (auto foldedIterDims : enumerate(iterationReassociation)) { 1292 foldedDimStartToSequenceMap[foldedIterDims.value()[0]] = 1293 foldedIterDims.index(); 1294 } 1295 } 1296 1297 // Returns the iteration space reassociation. 1298 ArrayRef<ReassociationIndices> getReassociationIndices() { 1299 return iterationReassociation; 1300 } 1301 1302 // Returns true if the given dimension is the start of a sequence of folded 1303 // dimensions. 1304 bool isDimStartOfFoldedDims(unsigned dim) { 1305 return foldedDimStartToSequenceMap.count(dim); 1306 } 1307 1308 // Return the folded dimensions starting at `dim`. 1309 ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) { 1310 assert(foldedDimStartToSequenceMap.count(dim) && 1311 "invalid start dim of folded dim " 1312 "sequence"); 1313 return iterationReassociation[foldedDimStartToSequenceMap[dim]]; 1314 } 1315 1316 // For a dim in the original op, return the dim in the collapsed op, that it 1317 // is mapped to. Expectes `dim` to be start of a folded dimension sequence. 1318 unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) { 1319 assert(foldedDimStartToSequenceMap.count(dim) && 1320 "invalid start dim of folded dim sequence"); 1321 return foldedDimStartToSequenceMap[dim]; 1322 } 1323 1324 private: 1325 /// Reassociation describing the folded iteration space dimensions. 1326 SmallVector<ReassociationIndices> iterationReassociation; 1327 1328 /// Map from the starting dimensions of the folded dimension sequences to 1329 /// their index in `iterationReassociation`. 1330 llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap; 1331 }; 1332 } // namespace 1333 1334 /// Get the iterator types for the collapsed operation given the original 1335 /// iterator types and collapsed dimensions. 1336 static SmallVector<StringRef> 1337 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes, 1338 CollapsingInfo &collapsingInfo) { 1339 SmallVector<StringRef> collapsedIteratorTypes; 1340 for (ReassociationIndicesRef foldedIterDims : 1341 collapsingInfo.getReassociationIndices()) { 1342 assert(!foldedIterDims.empty() && 1343 "reassociation indices expected to have non-empty sets"); 1344 // Just pick the iterator type of the first folded dim. Pre-condition checks 1345 // expected to have checked that iterator types of all folded dimensions are 1346 // the same. 1347 collapsedIteratorTypes.push_back( 1348 iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue()); 1349 } 1350 return collapsedIteratorTypes; 1351 } 1352 1353 /// Compute the indexing map in the collapsed op that corresponds to the given 1354 /// `indexingMap` of the original operation. 1355 static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, 1356 CollapsingInfo &collapsingInfo) { 1357 MLIRContext *context = indexingMap.getContext(); 1358 assert(indexingMap.isProjectedPermutation() && 1359 "expected indexing map to be projected permutation"); 1360 SmallVector<AffineExpr> resultExprs; 1361 for (auto expr : indexingMap.getResults()) { 1362 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 1363 if (collapsingInfo.isDimStartOfFoldedDims(dim)) { 1364 resultExprs.push_back(getAffineDimExpr( 1365 collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim), 1366 context)); 1367 } 1368 } 1369 return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0, 1370 resultExprs, context); 1371 } 1372 1373 /// Return the `reassociation` indices to use to collapse the operand when the 1374 /// iteration space of a generic op is collapsed. 1375 static SmallVector<ReassociationIndices> 1376 getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) { 1377 unsigned counter = 0; 1378 SmallVector<ReassociationIndices> operandReassociation; 1379 for (auto expr : indexingMap.getResults()) { 1380 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 1381 if (collapsingInfo.isDimStartOfFoldedDims(dim)) { 1382 unsigned numFoldedDims = 1383 collapsingInfo.getFoldedDimsStartingAt(dim).size(); 1384 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims); 1385 operandReassociation.emplace_back(range.begin(), range.end()); 1386 counter += numFoldedDims; 1387 } 1388 } 1389 return operandReassociation; 1390 } 1391 1392 /// Get the new value to use for a given `OpOperand` in the collapsed operation. 1393 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp, 1394 OpOperand *opOperand, 1395 CollapsingInfo &collapsingInfo, 1396 OpBuilder &builder) { 1397 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 1398 SmallVector<ReassociationIndices> operandReassociation = 1399 getOperandReassociation(indexingMap, collapsingInfo); 1400 1401 // If the number of entries in the reassocation for the operand is same as the 1402 // number of results of the indexing map, then nothing to do for this operand. 1403 Value operand = opOperand->get(); 1404 if (operandReassociation.size() == indexingMap.getNumResults()) 1405 return operand; 1406 1407 // Insert a reshape to collapse the dimensions. 1408 auto reshapeOp = builder.create<tensor::CollapseShapeOp>( 1409 loc, operand, operandReassociation); 1410 return reshapeOp.getResult(); 1411 } 1412 1413 /// Modify the `linalg.index` operations in the original generic op, to its 1414 /// value in the collapsed operation. 1415 void generateCollapsedIndexingRegion(Location loc, Block *block, 1416 CollapsingInfo &collapsingInfo, 1417 ValueRange loopRange, 1418 PatternRewriter &rewriter) { 1419 OpBuilder::InsertionGuard g(rewriter); 1420 rewriter.setInsertionPointToStart(block); 1421 1422 // Collect all the original index ops. 1423 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>()); 1424 1425 // For each folded dimension list resolve the original induction variable 1426 // values in terms of the folded dimension induction variable. 1427 // i_{folded} = (i_0 * d1 + i1) * d2 + i2. 1428 // can be inverted to 1429 // i2 = i_{folded} % d2 1430 // i1 = (i_{folded} / d2) % d1 1431 // i0 = i_{folded} / (d1 * d2) 1432 llvm::DenseMap<unsigned, Value> indexReplacementVals; 1433 for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) { 1434 ReassociationIndicesRef foldedDimsRef(foldedDims.value()); 1435 Value newIndexVal = 1436 rewriter.create<linalg::IndexOp>(loc, foldedDims.index()); 1437 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) { 1438 indexReplacementVals[dim] = 1439 rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]); 1440 newIndexVal = 1441 rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]); 1442 } 1443 indexReplacementVals[foldedDims.value().front()] = newIndexVal; 1444 } 1445 1446 for (auto indexOp : indexOps) { 1447 auto dim = indexOp.dim(); 1448 rewriter.replaceOp(indexOp, indexReplacementVals[dim]); 1449 } 1450 } 1451 1452 /// Implementation of fusion with reshape operation by collapsing dimensions. 1453 static Optional<SmallVector<Value>> 1454 fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp, 1455 OpOperand *fusableOpOperand, 1456 PatternRewriter &rewriter) { 1457 SmallVector<ReassociationIndices> reassociation = 1458 isa<tensor::CollapseShapeOp>(reshapeOp) 1459 ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices() 1460 : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices(); 1461 assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand, 1462 reassociation) && 1463 "preconditions for fusing with reshape by collapse failed"); 1464 1465 CollapsingInfo collapsingInfo(getDomainReassociation( 1466 genericOp.getTiedIndexingMap(fusableOpOperand), reassociation)); 1467 // Check for trivial no transformation cases. In that case return nothing. 1468 if (collapsingInfo.getReassociationIndices().size() == 1469 genericOp.getNumLoops()) 1470 return llvm::None; 1471 1472 // Get the iterator types for the operand. 1473 SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes( 1474 genericOp.iterator_types().getValue(), collapsingInfo); 1475 1476 // Get the indexing maps. 1477 auto indexingMaps = llvm::to_vector( 1478 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) { 1479 return getCollapsedOpIndexingMap(map, collapsingInfo); 1480 })); 1481 1482 Location loc = 1483 rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()}); 1484 1485 // Get the input operands. 1486 auto inputOperands = llvm::to_vector( 1487 llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) { 1488 return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, 1489 rewriter); 1490 })); 1491 1492 // Get the output operands and result types. 1493 SmallVector<Type> resultTypes; 1494 SmallVector<Value> outputOperands; 1495 resultTypes.reserve(genericOp.getNumOutputs()); 1496 outputOperands.reserve(genericOp.getNumOutputs()); 1497 for (OpOperand *output : genericOp.getOutputOperands()) { 1498 Value newOutput = 1499 getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter); 1500 outputOperands.push_back(newOutput); 1501 resultTypes.push_back(newOutput.getType()); 1502 } 1503 1504 // Create the generic op. 1505 auto collapsedGenericOp = rewriter.create<linalg::GenericOp>( 1506 loc, resultTypes, inputOperands, outputOperands, indexingMaps, 1507 iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); 1508 Block *origOpBlock = &genericOp->getRegion(0).front(); 1509 Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front(); 1510 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, 1511 collapsedOpBlock->getArguments()); 1512 1513 if (collapsedGenericOp.hasIndexSemantics()) { 1514 // Collect the loop range of the generic op. 1515 OpBuilder::InsertionGuard g(rewriter); 1516 rewriter.setInsertionPoint(collapsedGenericOp); 1517 SmallVector<Range> loopRanges = 1518 cast<LinalgOp>(genericOp.getOperation()) 1519 .createLoopRanges(rewriter, genericOp.getLoc()); 1520 assert(llvm::all_of(loopRanges, 1521 [](Range range) { 1522 return matchPattern(range.offset, m_Zero()) && 1523 matchPattern(range.stride, m_One()); 1524 }) && 1525 "expected all loop ranges to have zero start and unit stride"); 1526 SmallVector<Value> loopBound = llvm::to_vector( 1527 llvm::map_range(loopRanges, [](Range range) { return range.size; })); 1528 generateCollapsedIndexingRegion(loc, 1529 &collapsedGenericOp->getRegion(0).front(), 1530 collapsingInfo, loopBound, rewriter); 1531 } 1532 1533 // Insert expanding reshape for the result to get back the original result 1534 // type. 1535 SmallVector<Value> results; 1536 for (auto originalResult : llvm::enumerate(genericOp->getResults())) { 1537 Value collapsedOpResult = 1538 collapsedGenericOp->getResult(originalResult.index()); 1539 auto originalResultType = 1540 originalResult.value().getType().cast<ShapedType>(); 1541 auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>(); 1542 if (collapsedOpResultType.getRank() != originalResultType.getRank()) { 1543 AffineMap indexingMap = 1544 genericOp.getTiedIndexingMapForResult(originalResult.value()); 1545 SmallVector<ReassociationIndices> reassociation = 1546 getOperandReassociation(indexingMap, collapsingInfo); 1547 Value result = rewriter.create<tensor::ExpandShapeOp>( 1548 loc, originalResultType, collapsedOpResult, reassociation); 1549 results.push_back(result); 1550 } else { 1551 results.push_back(collapsedOpResult); 1552 } 1553 } 1554 return results; 1555 } 1556 1557 namespace { 1558 1559 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by 1560 /// contracting dimensions of the loop. 1561 class FoldWithProducerReshapeOpByCollapsing 1562 : public OpRewritePattern<GenericOp> { 1563 public: 1564 FoldWithProducerReshapeOpByCollapsing( 1565 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1566 PatternBenefit benefit = 1) 1567 : OpRewritePattern<GenericOp>(context, benefit), 1568 controlFoldingReshapes(std::move(foldReshapes)) {} 1569 1570 LogicalResult matchAndRewrite(GenericOp genericOp, 1571 PatternRewriter &rewriter) const override { 1572 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 1573 tensor::ExpandShapeOp reshapeOp = 1574 opOperand->get().getDefiningOp<tensor::ExpandShapeOp>(); 1575 if (!reshapeOp) 1576 continue; 1577 1578 if (!isFusableWithReshapeByDimCollapse( 1579 genericOp, opOperand, reshapeOp.getReassociationIndices()) || 1580 !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) { 1581 continue; 1582 } 1583 1584 Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing( 1585 genericOp, reshapeOp, opOperand, rewriter); 1586 if (!replacements) { 1587 return rewriter.notifyMatchFailure( 1588 genericOp, "failed to do the fusion by collapsing transformation"); 1589 } 1590 1591 rewriter.replaceOp(genericOp, replacements.getValue()); 1592 return success(); 1593 } 1594 return failure(); 1595 } 1596 1597 private: 1598 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1599 }; 1600 } // namespace 1601 1602 //===---------------------------------------------------------------------===// 1603 // Methods and patterns to convert tensor.expand_shape -> linalg.generic 1604 // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down. 1605 //===---------------------------------------------------------------------===// 1606 1607 // TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by 1608 // collapsing that provides a more general functionality. This pattern is very 1609 // specific to a particular use case. The fusion by collapsing can provide the 1610 // same control to clients using the control function there. 1611 1612 static SmallVector<ReassociationIndices> 1613 getReassociationIndices(ArrayRef<AffineMap> maps) { 1614 SmallVector<ReassociationIndices> reassociation; 1615 for (AffineMap map : maps) { 1616 ReassociationIndices indices; 1617 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 1618 unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition(); 1619 indices.push_back(pos); 1620 } 1621 reassociation.push_back(indices); 1622 } 1623 return reassociation; 1624 } 1625 1626 namespace { 1627 /// Pattern to move rank reducing reshape after an elementwise linalg generic 1628 /// op. This is useful to expose more fusion opportunities between named ops and 1629 /// generic ops. This can only be done if there is no broadcast or permuation 1630 /// within the dimensions we need to merge. 1631 /// 1632 /// For example, 1633 /// 1634 /// %0 = tensor.expand_shape %A [[0, 1], [2]] 1635 /// : tensor<12544x16xf32> into tensor<112x112x16xf32> 1636 /// %2 = linalg.generic {indexing_maps = [ 1637 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 1638 /// affine_map<(d0, d1, d2) -> (d2)>, 1639 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = 1640 /// ["parallel", "parallel", "parallel"]} { 1641 /// } -> tensor<112x112x16xf32> 1642 /// 1643 /// into 1644 /// 1645 /// %2 = linalg.generic {indexing_maps = [ 1646 /// affine_map<(d0, d1) -> (d0, d1)>, 1647 /// affine_map<(d0, d1) -> (d1)>, 1648 /// affine_map<(d0, d1) -> (d0, d1)>], 1649 /// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 1650 /// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) { 1651 /// } -> tensor<12544x16xf32> 1652 /// %3 = tensor.expand_shape %2 [[0, 1], [2]] 1653 /// : tensor<12544x16xf32> into tensor<112x112x16xf32> 1654 struct PushExpandingReshape : public OpRewritePattern<GenericOp> { 1655 using OpRewritePattern<GenericOp>::OpRewritePattern; 1656 1657 LogicalResult matchAndRewrite(GenericOp genericOp, 1658 PatternRewriter &rewriter) const override { 1659 // Only apply to elementwise linalg on tensor. 1660 if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() || 1661 genericOp.getNumParallelLoops() != genericOp.getNumLoops()) 1662 return failure(); 1663 // Only support identity output maps. It could be extended to permuations if 1664 // needed. 1665 if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) { 1666 return !genericOp.getTiedIndexingMap(opOperand).isIdentity(); 1667 })) 1668 return failure(); 1669 int64_t destRank = genericOp.getNumParallelLoops(); 1670 SmallVector<Value> newOperands = genericOp.getInputOperands(); 1671 tensor::ExpandShapeOp reshapeFound; 1672 // 1. Look for tensor_expand_shape operands and figure out save the 1673 // dimensions merged. 1674 SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands(); 1675 for (const auto &en : llvm::enumerate(inputOperands)) { 1676 auto reshapeOp = 1677 en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>(); 1678 if (!reshapeOp) 1679 continue; 1680 // TODO: We could support non-identity map as long as the merged 1681 // dimensions are still contiguous. 1682 if (!genericOp.getTiedIndexingMap(en.value()).isIdentity()) 1683 continue; 1684 if (reshapeFound) { 1685 // Only support a second reshape op if it has the same reassociate maps. 1686 if (reshapeFound.getReassociationMaps() == 1687 reshapeOp.getReassociationMaps()) 1688 newOperands[en.index()] = reshapeOp.src(); 1689 continue; 1690 } 1691 reshapeFound = reshapeOp; 1692 newOperands[en.index()] = reshapeOp.src(); 1693 } 1694 if (!reshapeFound) 1695 return failure(); 1696 1697 // Calculate the reassociation indices and rassociated reverse map. 1698 SmallVector<ReassociationIndices> reassociation = 1699 getReassociationIndices(reshapeFound.getReassociationMaps()); 1700 SmallVector<unsigned> remap(destRank); 1701 for (auto &indices : llvm::enumerate(reassociation)) { 1702 for (int64_t index : indices.value()) { 1703 remap[index] = indices.index(); 1704 } 1705 } 1706 // 2. Verify that we can merge the dimensions in the linalg and that we 1707 // don't need to create new reshapes operands. Inserting new reshape 1708 // operands would defeat the purpose of the transformation. 1709 for (const auto &en : llvm::enumerate(inputOperands)) { 1710 if (en.value()->get() == newOperands[en.index()]) { 1711 AffineMap map = genericOp.getTiedIndexingMap(en.value()); 1712 for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { 1713 if (reassociation[remap[map.getDimPosition(i)]].size() > 1) 1714 return failure(); 1715 } 1716 } 1717 } 1718 1719 // 3. Calculate the affine map remapping and the reassociation to apply to 1720 // output tensors. 1721 SmallVector<AffineMap> newMaps; 1722 unsigned newRank = reassociation.size(); 1723 for (auto map : genericOp.getIndexingMaps()) { 1724 SmallVector<AffineExpr> newExprs; 1725 for (auto expr : map.getResults()) { 1726 unsigned position = expr.template cast<AffineDimExpr>().getPosition(); 1727 // Skip dimension merged except for the last of the group. 1728 if (reassociation[remap[position]].back() == position) { 1729 newExprs.push_back( 1730 getAffineDimExpr(remap[position], genericOp.getContext())); 1731 } 1732 } 1733 newMaps.push_back( 1734 AffineMap::get(newRank, 0, newExprs, genericOp.getContext())); 1735 } 1736 1737 // 4. Reshape the output tensors. 1738 SmallVector<Value> newOutputs; 1739 SmallVector<Type> newOutputTypes; 1740 for (auto output : genericOp.outputs()) { 1741 auto newOutputType = RankedTensorType::get( 1742 reshapeFound.getSrcType().getShape(), 1743 output.getType().template cast<RankedTensorType>().getElementType()); 1744 Value newOutput = rewriter.create<tensor::CollapseShapeOp>( 1745 genericOp->getLoc(), newOutputType, output, reassociation); 1746 newOutputTypes.push_back(newOutputType); 1747 newOutputs.push_back(newOutput); 1748 } 1749 // 5. Create a new generic op with lowerer rank. 1750 SmallVector<StringRef> iteratorTypes(newRank, 1751 getParallelIteratorTypeName()); 1752 auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes, 1753 newOperands, newOutputs, newMaps, 1754 iteratorTypes); 1755 rewriter.inlineRegionBefore(genericOp.region(), newOp.region(), 1756 newOp.region().begin()); 1757 // 6. Reshape the so that the type matches the uses. 1758 SmallVector<Value> newResults; 1759 for (const auto &result : llvm::enumerate(newOp->getResults())) { 1760 newResults.push_back(rewriter.create<tensor::ExpandShapeOp>( 1761 genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()], 1762 result.value(), reassociation)); 1763 } 1764 rewriter.replaceOp(genericOp, newResults); 1765 return success(); 1766 } 1767 }; 1768 } // namespace 1769 1770 //===---------------------------------------------------------------------===// 1771 // Methods and patterns that fuse constants with linalg.generic operations. 1772 //===---------------------------------------------------------------------===// 1773 1774 namespace { 1775 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1776 /// handle cases where the constant is not single-valued. 1777 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> { 1778 public: 1779 FoldScalarOrSplatConstant(MLIRContext *context, 1780 ControlElementwiseOpsFusionFn &fun, 1781 PatternBenefit benefit = 1) 1782 : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {} 1783 1784 LogicalResult matchAndRewrite(GenericOp genericOp, 1785 PatternRewriter &rewriter) const override { 1786 if (!genericOp.hasTensorSemantics()) 1787 return failure(); 1788 for (OpOperand *opOperand : genericOp.getInputOperands()) { 1789 Operation *def = opOperand->get().getDefiningOp(); 1790 Attribute constantAttr; 1791 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool { 1792 { 1793 DenseElementsAttr splatAttr; 1794 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) && 1795 splatAttr.isSplat() && 1796 splatAttr.getType().getElementType().isIntOrFloat()) { 1797 constantAttr = splatAttr.getSplatValue<Attribute>(); 1798 return true; 1799 } 1800 } 1801 { 1802 IntegerAttr intAttr; 1803 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) { 1804 constantAttr = intAttr; 1805 return true; 1806 } 1807 } 1808 { 1809 FloatAttr floatAttr; 1810 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) { 1811 constantAttr = floatAttr; 1812 return true; 1813 } 1814 } 1815 return false; 1816 }; 1817 1818 auto resultValue = opOperand->get().dyn_cast<OpResult>(); 1819 if (!def || !resultValue || !isScalarOrSplatConstantOp(def) || 1820 !controlFn(resultValue, *opOperand)) 1821 continue; 1822 1823 // The operands and the indexing_maps of the fused operation the same as 1824 // the operands and indexing_maps of the generic operations with the 1825 // values at the constant index dropped. 1826 SmallVector<AffineMap> fusedIndexMaps; 1827 SmallVector<Value> fusedOperands; 1828 SmallVector<Location> fusedLocs{genericOp.getLoc()}; 1829 fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); 1830 fusedOperands.reserve(genericOp.getNumInputs()); 1831 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); 1832 for (OpOperand *inputOperand : genericOp.getInputOperands()) { 1833 if (inputOperand == opOperand) 1834 continue; 1835 Value inputValue = inputOperand->get(); 1836 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand)); 1837 fusedOperands.push_back(inputValue); 1838 fusedLocs.push_back(inputValue.getLoc()); 1839 } 1840 for (OpOperand *outputOperand : genericOp.getOutputOperands()) 1841 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand)); 1842 1843 // Check if the operation shapes to loops map is computable. 1844 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 1845 return rewriter.notifyMatchFailure( 1846 genericOp, "fused op loop bound computation failed"); 1847 } 1848 1849 // Create a constant scalar value from the splat constant. 1850 Value scalarConstant = rewriter.create<arith::ConstantOp>( 1851 def->getLoc(), constantAttr, constantAttr.getType()); 1852 1853 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 1854 auto fusedOp = rewriter.create<GenericOp>( 1855 rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), 1856 /*inputs=*/fusedOperands, 1857 /*outputs=*/outputOperands, 1858 rewriter.getAffineMapArrayAttr(fusedIndexMaps), 1859 genericOp.iterator_types(), 1860 /*doc=*/nullptr, 1861 /*library_call=*/nullptr); 1862 1863 // Map the block argument corresponding to the replaced argument with the 1864 // scalar constant. 1865 Region ®ion = genericOp->getRegion(0); 1866 Block &entryBlock = *region.begin(); 1867 BlockAndValueMapping mapping; 1868 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()), 1869 scalarConstant); 1870 Region &fusedRegion = fusedOp->getRegion(0); 1871 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(), 1872 mapping); 1873 rewriter.replaceOp(genericOp, fusedOp->getResults()); 1874 return success(); 1875 } 1876 return failure(); 1877 } 1878 1879 private: 1880 ControlElementwiseOpsFusionFn controlFn; 1881 }; 1882 1883 /// Base class for constant folding linalg.generic ops with N inputs, 1 output, 1884 /// and permutation indexing maps. 1885 /// 1886 /// `ConcreteType` should provide methods with signatures 1887 /// 1888 /// ```c++ 1889 /// bool matchIndexingMaps(GenericOp genericOp) const; 1890 /// RegionComputationFn getRegionComputeFn(GenericOp) const; 1891 /// ``` 1892 /// 1893 /// The latter inspects the region and returns the computation inside as a 1894 /// functor. The functor will be invoked with constant elements for all inputs 1895 /// and should return the corresponding computea constant element for output. 1896 template <typename ConcreteType> 1897 class FoldConstantBase : public OpRewritePattern<GenericOp> { 1898 public: 1899 struct APIntOrFloat { 1900 Optional<APInt> apInt; 1901 Optional<APFloat> apFloat; 1902 }; 1903 struct APIntOrFloatArray { 1904 SmallVector<APInt> apInts; 1905 SmallVector<APFloat> apFloats; 1906 }; 1907 using RegionComputationFn = 1908 std::function<APIntOrFloat(const APIntOrFloatArray &)>; 1909 1910 FoldConstantBase(MLIRContext *context, 1911 const ControlElementwiseOpsFusionFn &controlFn, 1912 PatternBenefit benefit = 1) 1913 : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {} 1914 1915 LogicalResult matchAndRewrite(GenericOp genericOp, 1916 PatternRewriter &rewriter) const override { 1917 if (genericOp.hasBufferSemantics()) 1918 return failure(); 1919 1920 // Only support ops generating one output for now. 1921 if (genericOp.getNumOutputs() != 1) 1922 return failure(); 1923 1924 auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>(); 1925 // Require the output types to be static give we are generating constants. 1926 if (!outputType || !outputType.hasStaticShape()) 1927 return failure(); 1928 1929 if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { 1930 return operand->get().getType().isa<ShapedType>(); 1931 })) 1932 return failure(); 1933 1934 // Make sure all element types are the same. 1935 auto getOperandElementType = [](OpOperand *operand) { 1936 return operand->get().getType().cast<ShapedType>().getElementType(); 1937 }; 1938 if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(), 1939 getOperandElementType))) 1940 return failure(); 1941 1942 // We can only handle the case where we have int/float elements. 1943 auto elementType = outputType.getElementType(); 1944 if (!elementType.isIntOrFloat()) 1945 return failure(); 1946 1947 // Require all indexing maps to be permutations for now. This is common and 1948 // it simplifies input/output access greatly: we can do the data shuffling 1949 // entirely in the compiler, without needing to turn all indices into 1950 // Values, and then do affine apply on them, and then match back the 1951 // constant again. 1952 if (!llvm::all_of(genericOp.getIndexingMaps(), 1953 [](AffineMap map) { return map.isPermutation(); })) 1954 return failure(); 1955 1956 for (OpOperand *operand : genericOp.getOutputOperands()) { 1957 if (genericOp.payloadUsesValueFromOperand(operand)) 1958 return failure(); 1959 } 1960 1961 // Further check the indexing maps are okay for the ConcreteType. 1962 if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp)) 1963 return failure(); 1964 1965 // Defer to the concrete type to check the region and discover the 1966 // computation inside. 1967 RegionComputationFn computeFn = 1968 static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp); 1969 if (!computeFn) 1970 return failure(); 1971 1972 // All inputs should be constants. 1973 int numInputs = genericOp.getNumInputs(); 1974 SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs); 1975 for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) { 1976 if (!matchPattern(operand.value()->get(), 1977 m_Constant(&inputValues[operand.index()]))) 1978 return failure(); 1979 } 1980 1981 // Identified this as a potential candidate for folding. Now check the 1982 // policy to see whether we are allowed to proceed. 1983 for (int i = 0; i < numInputs; ++i) { 1984 OpOperand *consumer = genericOp.getInputOperand(i); 1985 OpResult producer = consumer->get().cast<OpResult>(); 1986 if (!controlFn(producer, *consumer)) 1987 return failure(); 1988 } 1989 1990 auto linalgOp = cast<LinalgOp>(genericOp.getOperation()); 1991 SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes(); 1992 int64_t numElements = outputType.getNumElements(); 1993 1994 // Use APInt/APFloat instead of Attribute here for constructing the output. 1995 // This helps to avoid blowing up compiler memory usage: Attributes would 1996 // unify the following cases but they have lifetime as the MLIRContext. 1997 SmallVector<APInt> intOutputValues; 1998 SmallVector<APFloat> fpOutputValues; 1999 if (elementType.template isa<FloatType>()) 2000 fpOutputValues.resize(numElements, APFloat(0.f)); 2001 else 2002 intOutputValues.resize(numElements); 2003 2004 // Return the constant dim positions from the given permutation map. 2005 auto getDimPositions = [](AffineMap map) { 2006 SmallVector<unsigned> dims; 2007 dims.reserve(map.getNumResults()); 2008 for (AffineExpr result : map.getResults()) { 2009 dims.push_back(result.cast<AffineDimExpr>().getPosition()); 2010 } 2011 return dims; 2012 }; 2013 2014 SmallVector<SmallVector<unsigned>> inputDims; 2015 for (int i = 0; i < numInputs; ++i) 2016 inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i])); 2017 auto outputDims = getDimPositions(genericOp.getIndexingMaps().back()); 2018 auto outputShape = outputType.getShape(); 2019 2020 // Allocate small vectors for index delinearization. Initial values do not 2021 // matter here as they will be overwritten later. 2022 SmallVector<uint64_t> indices(loopBounds.size(), 0); 2023 SmallVector<uint64_t> dstIndices(loopBounds.size(), 0); 2024 SmallVector<SmallVector<uint64_t>> srcIndices( 2025 numInputs, SmallVector<uint64_t>(loopBounds.size(), 0)); 2026 SmallVector<uint64_t> srcLinearIndices(numInputs, 0); 2027 uint64_t dstLinearIndex = 0; 2028 2029 // Allocate spaces for compute function inputs. Initial values do not matter 2030 // here as they will be overwritten later. 2031 APIntOrFloatArray computeFnInputs; 2032 2033 auto inputShapes = llvm::to_vector<4>( 2034 llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { 2035 return operand->get().getType().cast<ShapedType>().getShape(); 2036 })); 2037 2038 // Given a `linearIndex`, remap it to a linear index to access linalg op 2039 // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`, 2040 // `srcLinearIndices`, `dstLinearIndex` in place. 2041 auto computeRemappedLinearIndex = [&](int linearIndex) { 2042 int totalCount = linearIndex; 2043 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { 2044 indices[dim] = totalCount % loopBounds[dim]; 2045 totalCount /= loopBounds[dim]; 2046 } 2047 2048 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { 2049 for (int i = 0; i < numInputs; ++i) 2050 srcIndices[i][dim] = indices[inputDims[i][dim]]; 2051 dstIndices[dim] = indices[outputDims[dim]]; 2052 } 2053 2054 dstLinearIndex = dstIndices.front(); 2055 for (int i = 0; i < numInputs; ++i) 2056 srcLinearIndices[i] = srcIndices[i].front(); 2057 2058 for (int dim = 1; dim < outputType.getRank(); ++dim) { 2059 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; 2060 for (int i = 0; i < numInputs; ++i) 2061 srcLinearIndices[i] = 2062 srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; 2063 } 2064 }; 2065 2066 bool isFloat = elementType.isa<FloatType>(); 2067 if (isFloat) { 2068 SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges; 2069 for (int i = 0; i < numInputs; ++i) 2070 inFpRanges.push_back(inputValues[i].getValues<APFloat>()); 2071 2072 computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); 2073 2074 // Transpose the input constant. Because we don't know its rank in 2075 // advance, we need to loop over the range [0, element count) and 2076 // delinearize the index. 2077 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { 2078 computeRemappedLinearIndex(linearIndex); 2079 2080 // Collect constant elements for all inputs at this loop iteration. 2081 for (int i = 0; i < numInputs; ++i) 2082 computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; 2083 2084 // Invoke the computation to get the corresponding constant output 2085 // element. 2086 fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; 2087 } 2088 } else { 2089 SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges; 2090 for (int i = 0; i < numInputs; ++i) 2091 inIntRanges.push_back(inputValues[i].getValues<APInt>()); 2092 2093 computeFnInputs.apInts.resize(numInputs); 2094 2095 // Transpose the input constant. Because we don't know its rank in 2096 // advance, we need to loop over the range [0, element count) and 2097 // delinearize the index. 2098 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { 2099 computeRemappedLinearIndex(linearIndex); 2100 2101 // Collect constant elements for all inputs at this loop iteration. 2102 for (int i = 0; i < numInputs; ++i) 2103 computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; 2104 2105 // Invoke the computation to get the corresponding constant output 2106 // element. 2107 intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; 2108 } 2109 } 2110 2111 DenseElementsAttr outputAttr = 2112 isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) 2113 : DenseElementsAttr::get(outputType, intOutputValues); 2114 2115 rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr); 2116 return success(); 2117 } 2118 2119 private: 2120 ControlElementwiseOpsFusionFn controlFn; 2121 }; 2122 2123 // Folds linalg.generic ops that are actually transposes on constant values. 2124 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> { 2125 using FoldConstantBase::FoldConstantBase; 2126 2127 bool matchIndexingMaps(GenericOp genericOp) const { 2128 // We should have one input and one output. 2129 return genericOp.getIndexingMaps().size() == 2; 2130 } 2131 2132 RegionComputationFn getRegionComputeFn(GenericOp genericOp) const { 2133 // Make sure the region only contains a yield op. 2134 Block &body = genericOp.region().front(); 2135 if (!llvm::hasSingleElement(body)) 2136 return nullptr; 2137 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); 2138 if (!yieldOp) 2139 return nullptr; 2140 2141 // The yield op should return the block argument corresponds to the input. 2142 for (Value yieldVal : yieldOp.values()) { 2143 auto yieldArg = yieldVal.dyn_cast<BlockArgument>(); 2144 if (!yieldArg || yieldArg.getOwner() != &body) 2145 return nullptr; 2146 if (yieldArg.getArgNumber() != 0) 2147 return nullptr; 2148 } 2149 2150 // No computation; just return the orginal value. 2151 return [](const APIntOrFloatArray &inputs) { 2152 if (inputs.apFloats.empty()) 2153 return APIntOrFloat{inputs.apInts.front(), llvm::None}; 2154 return APIntOrFloat{llvm::None, inputs.apFloats.front()}; 2155 }; 2156 } 2157 2158 ControlElementwiseOpsFusionFn controlFn; 2159 }; 2160 2161 } // namespace 2162 2163 //===---------------------------------------------------------------------===// 2164 // Miscellaneous patterns that help fusion. 2165 //===---------------------------------------------------------------------===// 2166 2167 namespace { 2168 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if 2169 /// the value of the `outs` operand is not used within the op. This is only 2170 /// implemented for `linalg.generic` operations for now, but should hold for all 2171 /// linalg structured ops. 2172 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> { 2173 using OpRewritePattern<GenericOp>::OpRewritePattern; 2174 2175 LogicalResult matchAndRewrite(GenericOp op, 2176 PatternRewriter &rewriter) const override { 2177 rewriter.startRootUpdate(op); 2178 bool modifiedOutput = false; 2179 Location loc = op.getLoc(); 2180 for (OpOperand *opOperand : op.getOutputOperands()) { 2181 if (!op.payloadUsesValueFromOperand(opOperand)) { 2182 Value operandVal = opOperand->get(); 2183 auto operandType = operandVal.getType().dyn_cast<RankedTensorType>(); 2184 if (!operandType) 2185 continue; 2186 2187 // If outs is already an `init_tensor` operation, nothing to do. 2188 auto definingOp = operandVal.getDefiningOp<InitTensorOp>(); 2189 if (definingOp) 2190 continue; 2191 modifiedOutput = true; 2192 SmallVector<Value> dynamicDims; 2193 for (const auto &dim : llvm::enumerate(operandType.getShape())) { 2194 if (dim.value() != ShapedType::kDynamicSize) 2195 continue; 2196 dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>( 2197 loc, operandVal, dim.index())); 2198 } 2199 Value initTensor = rewriter.create<InitTensorOp>( 2200 loc, dynamicDims, operandType.getShape(), 2201 operandType.getElementType()); 2202 op->setOperand(opOperand->getOperandNumber(), initTensor); 2203 } 2204 } 2205 if (!modifiedOutput) { 2206 rewriter.cancelRootUpdate(op); 2207 return failure(); 2208 } 2209 rewriter.finalizeRootUpdate(op); 2210 return success(); 2211 } 2212 }; 2213 } // namespace 2214 2215 //===---------------------------------------------------------------------===// 2216 // Methods that add patterns descrined in this file to a pattern list. 2217 //===---------------------------------------------------------------------===// 2218 2219 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns( 2220 RewritePatternSet &patterns) { 2221 patterns 2222 .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>, 2223 FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>, 2224 FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>, 2225 FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>( 2226 patterns.getContext()); 2227 } 2228 2229 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( 2230 RewritePatternSet &patterns) { 2231 patterns 2232 .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>, 2233 FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>, 2234 FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>, 2235 FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>( 2236 patterns.getContext()); 2237 } 2238 2239 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( 2240 RewritePatternSet &patterns, 2241 const ControlElementwiseOpsFusionFn &controlFoldingReshapes) { 2242 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(), 2243 controlFoldingReshapes); 2244 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), 2245 controlFoldingReshapes); 2246 } 2247 2248 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( 2249 RewritePatternSet &patterns, 2250 const ControlElementwiseOpsFusionFn &controlFoldingReshapes) { 2251 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(), 2252 controlFoldingReshapes); 2253 } 2254 2255 void mlir::linalg::populateElementwiseOpsFusionPatterns( 2256 RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) { 2257 auto *context = patterns.getContext(); 2258 patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant, 2259 FoldConstantTranspose>(context, 2260 options.controlElementwiseOpsFusionFn); 2261 patterns.add<RemoveOutsDependency>(context); 2262 populateFoldReshapeOpsByExpansionPatterns(patterns, 2263 options.controlFoldingReshapesFn); 2264 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 2265 GenericOp::getCanonicalizationPatterns(patterns, context); 2266 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); 2267 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); 2268 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns( 2269 patterns); 2270 } 2271 2272 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) { 2273 auto *context = patterns.getContext(); 2274 patterns.add<PushExpandingReshape>(context); 2275 } 2276 2277 //===---------------------------------------------------------------------===// 2278 // Passes 2279 //===---------------------------------------------------------------------===// 2280 2281 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, 2282 OpOperand &consumer) { 2283 if (auto producerCollapseOp = 2284 dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) { 2285 return !isUnitDimExpansionOnly(producerCollapseOp); 2286 } 2287 if (auto consumerExpandOp = 2288 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { 2289 return !isUnitDimExpansionOnly(consumerExpandOp); 2290 } 2291 return true; 2292 } 2293 2294 namespace { 2295 2296 /// Pass that fuses generic ops on tensors. Used only for testing. 2297 struct LinalgElementwiseOpFusionPass 2298 : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> { 2299 void runOnOperation() override { 2300 Operation *op = getOperation(); 2301 RewritePatternSet patterns(op->getContext()); 2302 ControlElementwiseOpsFusionFn allowFoldingFn = 2303 [](const OpResult &producer, const OpOperand &consumer) { 2304 return true; 2305 }; 2306 populateElementwiseOpsFusionPatterns( 2307 patterns, 2308 LinalgElementwiseFusionOptions().setControlFoldingReshapes( 2309 allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape)); 2310 2311 // Use TopDownTraversal for compile time reasons 2312 GreedyRewriteConfig grc; 2313 grc.useTopDownTraversal = true; 2314 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), 2315 grc); 2316 } 2317 }; 2318 2319 /// Pass to test folding of reshape ops with generic ops by linearization. 2320 struct FoldReshapeOpsByLinearizationPass 2321 : public LinalgFoldReshapeOpsByLinearizationBase< 2322 FoldReshapeOpsByLinearizationPass> { 2323 void runOnOperation() override { 2324 Operation *op = getOperation(); 2325 RewritePatternSet patterns(op->getContext()); 2326 populateFoldReshapeOpsByLinearizationPatterns(patterns); 2327 if (allowFoldingUnitDimReshapes) { 2328 populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns); 2329 } 2330 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); 2331 } 2332 }; 2333 2334 } // namespace 2335 2336 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() { 2337 return std::make_unique<LinalgElementwiseOpFusionPass>(); 2338 } 2339 2340 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() { 2341 return std::make_unique<FoldReshapeOpsByLinearizationPass>(); 2342 } 2343