//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//
#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for an operand of the
/// `producer`, given the indexing map of the producer's result in the consumer
/// (`fusedConsumerArgIndexMap`).
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from producer result tensor index ->
  // producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isInputTensor(consumerOpOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Currently only operations with a single result are supported.
  if (producer.getNumOutputs() != 1)
    return false;

  // Finally the indexing map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
      Value operand = std::get<0>(pair);
      if (operand == consumerOpOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlElementwiseOpsFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
      !controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(producer->getNumOperands() +
                        consumer->getNumOperands());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator it =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  for (OpOperand *opOperand : producer.getInputOperands()) {
    fusedOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // The fused op has invalid indexing maps. Typically this means something
    // is off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return llvm::None;
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
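///
/// As a rough illustrative sketch (SSA names made up for this comment), a
/// producer/consumer pair such as
///
///   %0 = linalg.generic ... ins(%a : tensor<?xf32>) outs(%init0 : tensor<?xf32>)
///   %1 = linalg.generic ... ins(%0, %b : tensor<?xf32>, tensor<?xf32>)
///          outs(%init1 : tensor<?xf32>)
///
/// is rewritten into a single linalg.generic over the consumer's iteration
/// space whose region contains the cloned producer body followed by the
/// consumer body, with the producer's yielded value forwarded to the block
/// argument that previously carried %0.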
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// linearization of indexing maps.
//===---------------------------------------------------------------------===//

// TODO(ravishankarm): The indexing maps these patterns produce in the general
// case are detrimental to transformations. They are on the deprecation path in
// favor of fusion by collapsing, which covers the only legitimate use case of
// these patterns: folding unit-extent dims.

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in the
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  // The new affine map cannot drop unused dimensions, but some new symbols may
  // have been added. Create a map with at least as many dimensions/symbols as
  // the original affine map.
  int64_t maxDim = -1;
  int64_t maxSym = -1;
  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
  return AffineMap::get(numDims, numSyms, resultExprs, context);
}

// tensor::ExpandShapeOp is only fusable with its consumer (i.e. with the
// reshape as the producer). Fusing when the operand has the higher rank would
// require the use of mods and divs in the indexing maps of the fused op, which
// would make it non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
    tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer)
    return false;
  return useIndexMap.isPermutation();
}

// tensor::CollapseShapeOp is only fusable with its producer (i.e. with the
// reshape as the consumer).
static bool
isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
                                         AffineMap useIndexMap,
                                         bool asProducer) {
  if (asProducer)
    return false;
  return useIndexMap.isPermutation();
}

/// Check if the reshape operation is only expansion into/collapsing of
/// unit-dimensions.
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t position : indices)
      if (expandedShape[position] == 1)
        numUnitDims++;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}

namespace {
/// Pattern to fold a tensor_expand_shape op with its consumer by using the
/// source of the reshape op as the operand in the consumer (instead of the
/// result of the tensor_expand_shape). The corresponding index map in the
/// consumer needs to be modified to linearize the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
///      : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap =
          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resulting op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

/// Pattern to fold a tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
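///
/// As an illustrative sketch (mirroring the `linearizeCollapsedDims` example
/// above), a generic op producing tensor<?x?x4x5xf32> followed by
///   tensor.collapse_shape %0 [[0], [1, 2, 3]]
///       : tensor<?x?x4x5xf32> into tensor<?x?xf32>
/// is folded by rewriting the generic's output indexing map from
///   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// to
///   affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
/// so that the generic produces the collapsed tensor<?x?xf32> directly.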
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer =*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resulting op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `fuseWithReshapeByExpansion`, which implements the
/// following fusion pattern.
///
/// Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor_expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing maps of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the operands of the linalg.generic are now accessed through 6-d
/// indexing maps, reshapes can be introduced to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from the original dimensions of the op to the
  // dimensions of the expanded op, given the `indexingMap` of the fused
  // operand/result of the generic op, the `reassociationMaps` of the reshape
  // op, and the shape of the expanded op.
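  // For example (an illustrative sketch): if the fused operand has indexing
  // map (d0, d1) -> (d1, d0) and the reshape reassociation is [[0, 1], [2]],
  // then d1 expands into two dimensions and d0 into one, giving the
  // reassociation d0 -> [0], d1 -> [1, 2] and an expanded op with 3 loops.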
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of loops in the original operation to the extent
  /// of loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
  originalLoopExtent.assign(originalLoopRange->begin(),
                            originalLoopRange->end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, the indices accessed in the original
/// operation need to be replaced with linearizations of the indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply`, which seems
/// to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded op to
/// get the value that can replace all uses of the results of the original op.
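/// For example (an illustrative sketch): for an operand with indexing map
/// (d0, d1) -> (d1, d0), where d1 expands into two dimensions and d0 stays a
/// single dimension, the operand reassociation is [[0, 1], [2]].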
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return llvm::None;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumInputs());
  for (OpOperand *opOperand : genericOp.getInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                               : collapsingReshapeOp.src());
      continue;
    }
    if (genericOp.isInputTensor(opOperand)) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(genericOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return llvm::None;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
    auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(genericOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return llvm::None;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
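///
/// As a rough sketch (types made up for this comment), a pair such as
///
///   %0 = linalg.generic ... -> tensor<?x?xf32>
///   %1 = tensor.expand_shape %0 [[0, 1], [2]]
///        : tensor<?x?xf32> into tensor<?x4x?xf32>
///
/// is replaced by an expanded generic op with one extra loop that produces
/// tensor<?x4x?xf32> directly, with its other operands reshaped to match the
/// expanded indexing maps.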
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
    if (!producer || producer.getNumOutputs() != 1 ||
        !isFusableWithReshapeByDimExpansion(producer,
                                            producer.getOutputOperand(0)) ||
        !controlFoldingReshapes(producer->getResult(0),
                                reshapeOp->getOpOperand(0)))
      return failure();
    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
    if (!replacementValues)
      return failure();
    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For an `indexingMap` that is a projected permutation, if the range is to be
/// collapsed using the given `reassociation`, get the reassociation in the
/// domain that would keep the map a projected permutation.
static SmallVector<ReassociationIndices>
getDomainReassociation(AffineMap indexingMap,
                       ArrayRef<ReassociationIndices> rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation map");
  unsigned counter = 0;
  SmallVector<ReassociationIndices> domainReassociation;
  llvm::SmallDenseSet<unsigned, 4> processedDomainDims;
  // Iterate over the reassociation indices.
  for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
    ReassociationIndices foldedDomainDims;
    for (auto rangeDim : foldedRangeDims) {
      (void)rangeDim;
      AffineDimExpr dimExpr =
          indexingMap.getResult(counter++).cast<AffineDimExpr>();
      foldedDomainDims.push_back(dimExpr.getPosition());
      processedDomainDims.insert(dimExpr.getPosition());
    }
    domainReassociation.emplace_back(std::move(foldedDomainDims));
  }
  // Fill in the missing domain dims.
  for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
    if (processedDomainDims.count(dim))
      continue;
    ReassociationIndices vec = {dim};
    domainReassociation.emplace_back(std::move(vec));
  }

  // Sort the reassociation using the first dimension of the folded range to
  // not create unnecessary transposes.
  llvm::sort(domainReassociation,
             [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
               return lhs[0] < rhs[0];
             });
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
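///
/// For example, the sequence [1, 2] is preserved by
/// (d0, d1, d2, d3) -> (d0, d1, d2) and by (d0, d1, d2, d3) -> (d0, d3)
/// (the sequence does not appear at all), but not by
/// (d0, d1, d2, d3) -> (d0, d2, d1) or (d0, d1, d2, d3) -> (d0, d2).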
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            indexingMap.getResult(expr.index() + dimInSequence.index())
                .cast<AffineDimExpr>()
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation enforces that all
      // AffineDimExprs in the result are unique, so no further checks are
      // needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

// Check if a generic op can be fused along an operand by collapsing
// dimensions.
static bool isFusableWithReshapeByDimCollapse(
    GenericOp genericOp, OpOperand *fusableOperand,
    ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
    return false;

  if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return false;
  }

  // Get the reassociation for the iteration space.
  SmallVector<ReassociationIndices> iterationReassociation =
      getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
                             reassociation);
  if (iterationReassociation.empty()) {
    // If the domain reassociation indices are empty, then this is a scalar op.
    // Nothing to do.
    return false;
  }

  auto iteratorTypes = genericOp.iterator_types().getValue();
  ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
  for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
    // Check that for all indexing maps, the folded dimensions sequence is
    // preserved.
    if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
          return isDimSequencePreserved(indexingMap, foldedIterDims);
        }))
      return false;
    unsigned startDim = foldedIterDims[0];
    ArrayRef<Attribute> foldedIteratorTypes =
        iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
    // Check that all folded iterator types are either all parallel, or all
    // reduction.
    if (!llvm::all_of(
            foldedIteratorTypes,
            [](Attribute attr) { return isParallelIterator(attr); }) &&
        !llvm::all_of(foldedIteratorTypes,
                      [](Attribute attr) { return isReductionIterator(attr); }))
      return false;
  }
  return true;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
    iterationReassociation = std::move(reassociation);
    for (const auto &foldedIterDims : enumerate(iterationReassociation)) {
      foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
          foldedIterDims.index();
    }
  }

  // Returns the iteration space reassociation.
  ArrayRef<ReassociationIndices> getReassociationIndices() {
    return iterationReassociation;
  }

  // Returns true if the given dimension is the start of a sequence of folded
  // dimensions.
  bool isDimStartOfFoldedDims(unsigned dim) {
    return foldedDimStartToSequenceMap.count(dim);
  }

  // Return the folded dimensions starting at `dim`.
  ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
    assert(foldedDimStartToSequenceMap.count(dim) &&
           "invalid start dim of folded dim sequence");
    return iterationReassociation[foldedDimStartToSequenceMap[dim]];
  }

  // For a dim in the original op, return the dim in the collapsed op that it
  // is mapped to. Expects `dim` to be the start of a folded dimension
  // sequence.
  unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
    assert(foldedDimStartToSequenceMap.count(dim) &&
           "invalid start dim of folded dim sequence");
    return foldedDimStartToSequenceMap[dim];
  }

private:
  /// Reassociation describing the folded iteration space dimensions.
  SmallVector<ReassociationIndices> iterationReassociation;

  /// Map from the starting dimensions of the folded dimension sequences to
  /// their index in `iterationReassociation`.
  llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<StringRef>
getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
                            CollapsingInfo &collapsingInfo) {
  SmallVector<StringRef> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getReassociationIndices()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have verified that the iterator types of all
    // folded dimensions are the same.
    collapsedIteratorTypes.push_back(
        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
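///
/// For example, with the iteration space reassociation [[0, 1], [2]] (d0 and
/// d1 fold into the new d0, and d2 becomes the new d1), the original map
/// (d0, d1, d2) -> (d2, d0, d1) becomes (d0, d1) -> (d1, d0).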
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
                                           CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
      resultExprs.push_back(getAffineDimExpr(
          collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
          context));
    }
  }
  return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
                        resultExprs, context);
}

/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
      unsigned numFoldedDims =
          collapsingInfo.getFoldedDimsStartingAt(dim).size();
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
      counter += numFoldedDims;
    }
  }
  return operandReassociation;
}

/// Get the new value to use for a given `OpOperand` in the collapsed
/// operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                   OpOperand *opOperand,
                                   CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, then there is nothing to do
  // for this operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
      loc, operand, operandReassociation);
  return reshapeOp.getResult();
}

/// Modify the `linalg.index` operations in the original generic op to their
/// values in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block,
                                     CollapsingInfo &collapsingInfo,
                                     ValueRange loopRange,
                                     PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list, resolve the original induction variable
  // values in terms of the folded dimension induction variable:
  //   i_{folded} = (i0 * d1 + i1) * d2 + i2
  // which can be inverted to:
  //   i2 = i_{folded} % d2
  //   i1 = (i_{folded} / d2) % d1
  //   i0 = i_{folded} / (d1 * d2)
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      indexReplacementVals[dim] =
          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
      newIndexVal =
          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.dim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}

/// Implementation of fusion with reshape operation by collapsing dimensions.
static Optional<SmallVector<Value>>
fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
                            OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
  SmallVector<ReassociationIndices> reassociation =
      isa<tensor::CollapseShapeOp>(reshapeOp)
          ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
          : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
  assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
                                           reassociation) &&
         "preconditions for fusing with reshape by collapse failed");

  CollapsingInfo collapsingInfo(getDomainReassociation(
      genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
  // Check for the trivial case where no transformation is needed; return
  // nothing in that case.
  if (collapsingInfo.getReassociationIndices().size() ==
      genericOp.getNumLoops())
    return llvm::None;

  // Get the iterator types for the collapsed operation.
  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
      genericOp.iterator_types().getValue(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  Location loc =
      rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});

  // Get the input operands.
  auto inputOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
                                     rewriter);
      }));

  // Get the output operands and result types.
  SmallVector<Type> resultTypes;
  SmallVector<Value> outputOperands;
  resultTypes.reserve(genericOp.getNumOutputs());
  outputOperands.reserve(genericOp.getNumOutputs());
  for (OpOperand *output : genericOp.getOutputOperands()) {
    Value newOutput =
        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    resultTypes.push_back(newOutput.getType());
  }

  // Create the generic op.
  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &genericOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());

  if (collapsedGenericOp.hasIndexSemantics()) {
    // Collect the loop ranges of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedGenericOp);
    SmallVector<Range> loopRanges =
        cast<LinalgOp>(genericOp.getOperation())
            .createLoopRanges(rewriter, genericOp.getLoc());
    assert(llvm::all_of(loopRanges,
                        [](Range range) {
                          return matchPattern(range.offset, m_Zero()) &&
                                 matchPattern(range.stride, m_One());
                        }) &&
           "expected all loop ranges to have zero start and unit stride");
    SmallVector<Value> loopBound = llvm::to_vector(
        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
    generateCollapsedIndexingRegion(loc,
                                    &collapsedGenericOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert an expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
    Value collapsedOpResult =
        collapsedGenericOp->getResult(originalResult.index());
    auto originalResultType =
        originalResult.value().getType().cast<ShapedType>();
    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result = rewriter.create<tensor::ExpandShapeOp>(
          loc, originalResultType, collapsedOpResult, reassociation);
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return results;
}

namespace {

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing dimensions of the loop.
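///
/// As an illustrative sketch (shapes and maps are hypothetical), an IR
/// snippet of the form
///
///   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
///       : tensor<12x16xf32> into tensor<3x4x16xf32>
///   %1 = linalg.generic {
///          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///                           affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
///          iterator_types = ["parallel", "parallel", "parallel"]}
///          ins(%0 : tensor<3x4x16xf32>) outs(%init : tensor<3x4x16xf32>) {...}
///
/// is rewritten, roughly, into
///
///   %collapsed = tensor.collapse_shape %0 [[0, 1], [2]]
///       : tensor<3x4x16xf32> into tensor<12x16xf32>
///   %1 = linalg.generic {
///          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                           affine_map<(d0, d1) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel"]}
///          ins(%collapsed : tensor<12x16xf32>)
///          outs(%init_collapsed : tensor<12x16xf32>) {...}
///   %2 = tensor.expand_shape %1 [[0, 1], [2]]
///       : tensor<12x16xf32> into tensor<3x4x16xf32>
///
/// where the collapse_shape of %0 then cancels with the producing
/// expand_shape on canonicalization.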
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      if (!isFusableWithReshapeByDimCollapse(
              genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
        continue;
      }

      Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
          genericOp, reshapeOp, opOperand, rewriter);
      if (!replacements) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, replacements.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to convert tensor.expand_shape -> linalg.generic
// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
//===---------------------------------------------------------------------===//

// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
// collapsing, which provides more general functionality. This pattern is very
// specific to a particular use case. The fusion by collapsing can provide the
// same control to clients using the control function there.

static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

namespace {
/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
///
/// For example,
///
///   %0 = tensor.expand_shape %A [[0, 1], [2]]
///       : tensor<12544x16xf32> into tensor<112x112x16xf32>
///   %2 = linalg.generic {indexing_maps = [
///     affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///     affine_map<(d0, d1, d2) -> (d2)>,
///     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
///     iterator_types = ["parallel", "parallel", "parallel"]} {
///   } -> tensor<112x112x16xf32>
///
/// into
///
///   %2 = linalg.generic {indexing_maps = [
///     affine_map<(d0, d1) -> (d0, d1)>,
///     affine_map<(d0, d1) -> (d1)>,
///     affine_map<(d0, d1) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///     : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///   } -> tensor<12544x16xf32>
///   %3 = tensor.expand_shape %2 [[0, 1], [2]]
///       : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg ops on tensors.
    if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. This could be extended to
    // permutations if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor.expand_shape operands and record the merged
    // dimensions.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity maps as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the associated reverse map.
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg op and that we
    // don't need to create new reshape operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (const auto &en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip merged dimensions except for the last one of each group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that their types match the uses.
    SmallVector<Value> newResults;
    for (const auto &result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does
/// not handle cases where the constant is not single-valued.
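///
/// As an illustrative sketch (types and values are hypothetical):
///
///   %cst = arith.constant dense<2.0> : tensor<4xf32>
///   %0 = linalg.generic {
///          indexing_maps = [affine_map<(d0) -> (d0)>,
///                           affine_map<(d0) -> (d0)>,
///                           affine_map<(d0) -> (d0)>],
///          iterator_types = ["parallel"]}
///          ins(%arg0, %cst : tensor<4xf32>, tensor<4xf32>)
///          outs(%init : tensor<4xf32>) {
///     ^bb0(%a: f32, %b: f32, %out: f32):
///       %m = arith.mulf %a, %b : f32
///       linalg.yield %m : f32
///   } -> tensor<4xf32>
///
/// is rewritten, roughly, into
///
///   %cst = arith.constant 2.000000e+00 : f32
///   %0 = linalg.generic {
///          indexing_maps = [affine_map<(d0) -> (d0)>,
///                           affine_map<(d0) -> (d0)>],
///          iterator_types = ["parallel"]}
///          ins(%arg0 : tensor<4xf32>) outs(%init : tensor<4xf32>) {
///     ^bb0(%a: f32, %out: f32):
///       %m = arith.mulf %a, %cst : f32
///       linalg.yield %m : f32
///   } -> tensor<4xf32>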
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation, with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation's shapes-to-loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the
/// output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given we are generating constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, and then do affine apply on them, and then match back the
    // constant again.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check that the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases but they have the same lifetime as the
    // MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate space for the compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, and `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the
    // input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), llvm::None};
      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
    };
  }

  ControlElementwiseOpsFusionFn controlFn;
};

} // namespace

//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparse compiler.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<
      FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
      FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
      FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
      patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
               FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateSparseTensorRewriting(patterns);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
      patterns);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    ControlElementwiseOpsFusionFn allowFoldingFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          return true;
        };
    populateElementwiseOpsFusionPatterns(
        patterns,
        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));

    // Use TopDownTraversal for compile-time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}
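
// A minimal usage sketch for clients other than the test passes above
// (the surrounding pass type `MyFusionPass` is hypothetical; the populate*
// entry points are the ones defined in this file):
//
//   void MyFusionPass::runOnOperation() {
//     Operation *op = getOperation();
//     RewritePatternSet patterns(op->getContext());
//     linalg::populateElementwiseOpsFusionPatterns(
//         patterns, linalg::LinalgElementwiseFusionOptions());
//     (void)applyPatternsAndFoldGreedily(op->getRegions(),
//                                        std::move(patterns));
//   }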