1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===/// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion on tensors operations pass. 10 // 11 //===----------------------------------------------------------------------===// 12 #include <utility> 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/IR/AffineExpr.h" 21 #include "mlir/IR/AffineMap.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Support/LLVM.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 using namespace mlir; 28 using namespace mlir::linalg; 29 30 //===---------------------------------------------------------------------===// 31 // Methods and patterns that fuse elementwise `linalg.generic` operations. 32 //===---------------------------------------------------------------------===// 33 34 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of 35 /// the `producer` to use in the fused operation given the indexing map of the 36 /// result of the producer in the consumer. 37 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 38 OpOperand *producerOpOperand, AffineMap producerResultIndexMap, 39 AffineMap fusedConsumerArgIndexMap) { 40 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 41 // from consumer loop -> consumer arg tensor index/producer result tensor 42 // index. 
The fused loop is same as the consumer loop. For each producer arg 43 // the indexing map to be computed is a map from consumer loop -> producer 44 // arg tensor index. 45 // producerResultIndexMap is a map from producer loop -> tensor index. 46 // Compute the inverse to get map from tensor index -> producer loop. 47 // The inverse is a map from producer result tensor index -> producer loop. 48 AffineMap invProducerResultIndexMap = 49 inversePermutation(producerResultIndexMap); 50 assert(invProducerResultIndexMap && 51 "expected producer result indexig map to be invertible"); 52 53 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner()); 54 // argMap is a map from producer loop -> producer arg tensor index. 55 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); 56 57 // Compose argMap with invProducerResultIndexMap to get a map from 58 // producer result tensor index -> producer arg tensor index. 59 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 60 61 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 62 // consumer loop/ fused loop -> producer arg tensor index. 63 return t1.compose(fusedConsumerArgIndexMap); 64 } 65 66 /// Conditions for elementwise fusion of generic operations. 67 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, 68 OpOperand *consumerOpOperand) { 69 // Producer and consumer must have tensor semantics. 70 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 71 return false; 72 73 // Verify that 74 // - the producer has all "parallel" iterator type. 75 if (producer.getNumParallelLoops() != producer.getNumLoops()) 76 return false; 77 78 // Only allow fusing the producer of an input operand for now. 79 // TODO: allow fusing the producer of an output operand. 80 if (!consumer.isInputTensor(consumerOpOperand)) 81 return false; 82 83 // Get the consumer index map. 
The number of results of the consumer index 84 // map must match the number of loops of the producer. 85 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 86 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 87 return false; 88 89 // Currently support only operations with single result. 90 if (producer.getNumOutputs() != 1) 91 return false; 92 93 // Finally the index_map for the result must be invertible. For now just 94 // verify it is a permutation. 95 AffineMap producerResultIndexMap = 96 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 97 if (!producerResultIndexMap.isPermutation()) 98 return false; 99 100 // Ensure that the fusion does not remove size information required to 101 // get the loop bounds. For non-reduction generics, this is trivially the 102 // case due to the output operand. For reductions, we need to check that after 103 // the fusion, each loop dimension has at least one input that defines it. 104 if ((consumer.getNumReductionLoops())) { 105 BitVector coveredDims(consumer.getNumLoops(), false); 106 107 auto addToCoveredDims = [&](AffineMap map) { 108 for (auto result : map.getResults()) 109 if (auto dimExpr = result.dyn_cast<AffineDimExpr>()) 110 coveredDims[dimExpr.getPosition()] = true; 111 }; 112 113 for (auto pair : 114 llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) { 115 Value operand = std::get<0>(pair); 116 if (operand == consumerOpOperand->get()) 117 continue; 118 AffineMap operandMap = std::get<1>(pair); 119 addToCoveredDims(operandMap); 120 } 121 122 for (OpOperand *operand : producer.getInputOperands()) { 123 AffineMap newIndexingMap = 124 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 125 operand, producerResultIndexMap, consumerIndexMap); 126 addToCoveredDims(newIndexingMap); 127 } 128 if (!coveredDims.all()) 129 return false; 130 } 131 132 return true; 133 } 134 135 /// Generate the region of the fused tensor operation. 
The region of the fused 136 /// op must be empty. 137 static void 138 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 139 AffineMap consumerToProducerLoopsMap, 140 OpOperand *consumerOpOperand, 141 unsigned nloops) { 142 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 143 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 144 // Build the region of the fused op. 145 Block &producerBlock = producer->getRegion(0).front(); 146 Block &consumerBlock = consumer->getRegion(0).front(); 147 Block *fusedBlock = new Block(); 148 fusedOp.region().push_back(fusedBlock); 149 BlockAndValueMapping mapper; 150 OpBuilder::InsertionGuard guard(rewriter); 151 rewriter.setInsertionPointToStart(fusedBlock); 152 153 // 2. Add an index operation for every fused loop dimension and use the 154 // `consumerToProducerLoopsMap` to map the producer indices. 155 if (producer.hasIndexSemantics()) { 156 // Add an index operation for every fused loop dimension. 157 unsigned numFusedOpLoops = 158 std::max(producer.getNumLoops(), consumer.getNumLoops()); 159 SmallVector<Value> fusedIndices; 160 fusedIndices.reserve(numFusedOpLoops); 161 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 162 std::back_inserter(fusedIndices), [&](uint64_t dim) { 163 return rewriter.create<IndexOp>(producer.getLoc(), dim); 164 }); 165 for (IndexOp indexOp : 166 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 167 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 168 producer.getLoc(), 169 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 170 mapper.map(indexOp.getResult(), newIndex); 171 } 172 } 173 // TODO: allow fusing the producer of an output operand. 174 assert(consumer.isInputTensor(consumerOpOperand) && 175 "expected producer of input operand"); 176 // 3. Consumer input operands up to consumerIdx (exclusive). 
177 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 178 consumerOpOperand->getOperandNumber())) // input assumption. 179 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 180 181 // Replacing consumerIdx requires getting the cloned, yielded, value from 182 // the (cloned) producer block. This happens in step 9. 183 184 // 4. Splice in producer's input operands. 185 for (BlockArgument bbArg : 186 producerBlock.getArguments().take_front(producer.getNumInputs())) 187 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 188 189 // 4.b. Producer output operand/map that is fused needs to be mapped to the 190 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 191 assert(producer->getNumResults() == 1 && "expected single result producer"); 192 if (producer.isInitTensor(producer.getOutputOperand(0))) { 193 BlockArgument bbArg = producerBlock.getArguments() 194 .drop_front(producer.getNumInputs()) 195 // TODO: bbArg index of 196 .front(); 197 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 198 } 199 // 5. Remaining consumer's input operands (drop past index `consumerIdx`). 200 for (BlockArgument bbArg : 201 consumerBlock.getArguments() 202 .take_front(consumer.getNumInputs()) 203 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 204 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 205 // 6. All of consumer's output operands. 206 for (BlockArgument bbArg : 207 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 208 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); 209 // 7. All of producer's output operands except the one fused. 210 // TODO: allow fusion of multi-result producers. 211 assert(producer->getNumResults() == 1 && "expected single result producer"); 212 213 // 8. Clone all producer operations except for the yield and index operations 214 // to the fused operation. 
215 for (auto &op : producerBlock.without_terminator()) { 216 if (!isa<IndexOp>(op)) 217 rewriter.clone(op, mapper); 218 } 219 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just 220 // forward the yield operand. 221 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 222 // TODO: allow fusion of multi-result producers. 223 assert(producer->getNumResults() == 1 && "expected single result producer"); 224 unsigned producerResultNumber = 0; 225 Value replacement = 226 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 227 // Sanity checks, if replacement is not already in the mapper then it must be 228 // produced outside. 229 if (replacement == yieldOp.getOperand(producerResultNumber)) { 230 if (auto bb = replacement.dyn_cast<BlockArgument>()) 231 assert(bb.getOwner() != &producerBlock && 232 "yielded block argument must have been mapped"); 233 else 234 assert(!producer->isAncestor(replacement.getDefiningOp()) && 235 "yielded value must have been mapped"); 236 } 237 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 238 replacement); 239 // 10. Clone operations from the consumer to the fused op. 240 for (auto &op : consumerBlock.getOperations()) 241 rewriter.clone(op, mapper); 242 243 // Sanity checks. 244 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 245 "Ill-formed GenericOp region"); 246 } 247 248 static Optional<SmallVector<Value>> 249 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 250 const ControlElementwiseOpsFusionFn &controlFn, 251 PatternRewriter &rewriter) { 252 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 253 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 254 !controlFn(producer->getResult(0), *consumerOpOperand)) 255 return llvm::None; 256 257 // TODO: allow fusing the producer of an output operand. 
258 assert(consumer.isInputTensor(consumerOpOperand) && 259 "expected producer of input operand"); 260 261 // Compute the fused operands list and indexing maps. 262 SmallVector<Value> fusedOperands; 263 SmallVector<AffineMap> fusedIndexMaps; 264 fusedOperands.reserve(producer->getNumOperands() + 265 consumer->getNumOperands()); 266 fusedIndexMaps.reserve(producer->getNumOperands() + 267 consumer->getNumOperands()); 268 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 269 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 270 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 271 SmallVector<OpOperand *>::iterator it = 272 llvm::find(consumerInputs, consumerOpOperand); 273 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 274 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 275 fusedOperands.push_back(opOperand->get()); 276 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 277 } 278 // 4. Splice in producer's input operands/maps. 279 assert(producer->getNumResults() == 1 && "expected single result producer"); 280 AffineMap producerResultIndexMap = 281 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 282 for (OpOperand *opOperand : producer.getInputOperands()) { 283 fusedOperands.push_back(opOperand->get()); 284 // Compute indexing maps for the producer args in the fused operation. 285 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 286 opOperand, producerResultIndexMap, 287 consumer.getTiedIndexingMap(consumerOpOperand)); 288 fusedIndexMaps.push_back(map); 289 } 290 // 4.b. Producer output operand/map that is fused needs to be passed if it is 291 // an "initTensor" (i.e. its value is actually read). 
292 assert(producer->getNumResults() == 1 && "expected single result producer"); 293 if (producer.isInitTensor(producer.getOutputOperand(0))) { 294 fusedOperands.push_back(producer.getOutputOperand(0)->get()); 295 // Compute indexing maps for the producer args in the fused operation. 296 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 297 producer.getOutputOperand(0), producerResultIndexMap, 298 consumer.getTiedIndexingMap(consumerOpOperand)); 299 fusedIndexMaps.push_back(map); 300 } 301 // 5. Remaining consumer's input operands/maps (drop past index 302 // `consumerIdx`). 303 for (OpOperand *opOperand : 304 llvm::make_range(std::next(it), consumerInputs.end())) { 305 fusedOperands.push_back(opOperand->get()); 306 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 307 } 308 // 6. All of consumer's output operands (skip operands: added by the builder). 309 for (OpOperand *opOperand : consumer.getOutputOperands()) 310 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 311 // 7. All of producer's output operands/maps except the one fused. 312 // TODO: allow fusion of multi-result producers. 313 assert(producer->getNumResults() == 1 && "expected single result producer"); 314 315 // Generate the fused op. 316 SmallVector<Value> consumerOutputs = consumer.getOutputOperands(); 317 auto fusedOp = rewriter.create<GenericOp>( 318 consumer.getLoc(), consumer->getResultTypes(), 319 /*inputs=*/fusedOperands, 320 // TODO: handle outputs. 321 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 322 consumer.iterator_types(), 323 /*doc=*/nullptr, 324 /*library_call=*/nullptr); 325 if (!fusedOp.getShapesToLoopsMap()) { 326 // Fused op has invalid indexing maps. Typically this means something is off 327 // in the input, but going ahead here would result in verification errors. 328 // So cleanup and abort. 
329 rewriter.eraseOp(fusedOp); 330 return llvm::None; 331 } 332 333 // Construct an AffineMap from consumer loops to producer loops. 334 // consumer loop -> tensor index 335 AffineMap consumerResultIndexMap = 336 consumer.getTiedIndexingMap(consumerOpOperand); 337 // tensor index -> producer loop 338 AffineMap invProducerResultIndexMap = 339 inversePermutation(producerResultIndexMap); 340 assert(invProducerResultIndexMap && 341 "expected producer result indexig map to be invertible"); 342 // consumer loop -> producer loop 343 AffineMap consumerToProducerLoopsMap = 344 invProducerResultIndexMap.compose(consumerResultIndexMap); 345 346 generateFusedElementwiseOpRegion(rewriter, fusedOp, 347 consumerToProducerLoopsMap, 348 consumerOpOperand, consumer.getNumLoops()); 349 return SmallVector<Value>(fusedOp->getResults()); 350 } 351 352 static Optional<SmallVector<Value>> 353 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, 354 GenericOp producer, 355 const ControlElementwiseOpsFusionFn &controlFn) { 356 if (producer->getNumResults() != 1) 357 return llvm::None; 358 359 return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn, 360 rewriter); 361 } 362 363 namespace { 364 /// Patterns to fuse a generic op, with the producer of its operands. 365 class FuseElementwiseOps : public OpRewritePattern<GenericOp> { 366 public: 367 FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun, 368 PatternBenefit benefit = 1) 369 : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {} 370 371 LogicalResult matchAndRewrite(GenericOp genericOp, 372 PatternRewriter &rewriter) const override { 373 // Find the first operand that is defined by another generic op on tensors. 
374 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 375 auto producer = 376 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp()); 377 if (!producer || !producer.hasTensorSemantics()) 378 continue; 379 Optional<SmallVector<Value>> fusedOpResults = 380 fuseElementwiseOps(rewriter, opOperand, producer, controlFn); 381 if (fusedOpResults) { 382 rewriter.replaceOp(genericOp, *fusedOpResults); 383 return success(); 384 } 385 } 386 return failure(); 387 } 388 389 private: 390 ControlElementwiseOpsFusionFn controlFn; 391 }; 392 } // namespace 393 394 //===---------------------------------------------------------------------===// 395 // Methods and patterns that fuse reshape ops with elementwise operations by 396 // linearization of indexing maps. 397 //===---------------------------------------------------------------------===// 398 399 // TODO(ravishankarm): These patterns need to be deprecated. The indexing maps 400 // these produce in the general case are detrimental to transformations. 401 // They are useful now only in the limited case of unit-dimension folding. 402 // Remove these in favor of more general folding by dimension contraction. 403 404 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` 405 /// provided, given the shape of the source tensor that corresponds to the 406 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions 407 /// are "row-major" ordered logically. 408 /// 409 /// For example: 410 /// 411 /// %0 = op ... : tensor<?x?x4x5xf32> 412 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` 413 /// 414 /// and reshape: 415 /// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] : 416 /// tensor<?x?x4x5xf32> into tensor<?x?xf32> 417 /// 418 /// would be rewritten into: 419 /// %0 = op ... 
: tensor<?x?x4x5xf32> 420 /// with output index_map 421 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` 422 template <typename TensorReshapeOp> 423 static AffineMap linearizeCollapsedDims(AffineMap sourceMap, 424 TensorReshapeOp reshapeOp) { 425 constexpr bool isExpanding = 426 std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value; 427 ArrayRef<int64_t> sourceShape = 428 (isExpanding ? reshapeOp.getResultType().getShape() 429 : reshapeOp.getSrcType().getShape()); 430 SmallVector<AffineExpr> resultExprs; 431 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults(); 432 MLIRContext *context = sourceMap.getContext(); 433 434 // Compute the result exprs based on the reassociation maps. 435 for (auto &indices : reshapeOp.getReassociationIndices()) { 436 // Assume that they are in-order and contiguous (already checked in 437 // verifier). 438 assert(!indices.empty()); 439 SmallVector<int64_t> sizes; 440 SmallVector<AffineExpr> dimExprs; 441 for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()), 442 sourceExprs.slice(indices[0], indices.size()))) { 443 if (std::get<0>(en) == 1) 444 continue; 445 sizes.push_back(std::get<0>(en)); 446 dimExprs.push_back(std::get<1>(en)); 447 } 448 AffineExpr linearizedExpr = 449 makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); 450 resultExprs.push_back(linearizedExpr); 451 } 452 // The new affine map cannot drop unused dimension but some new symbols may 453 // have been added. Create a map with at least as many dimensions/symbols as 454 // the original affine map. 455 int64_t maxDim = -1; 456 int64_t maxSym = -1; 457 getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym); 458 unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims()); 459 unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols()); 460 return AffineMap::get(numDims, numSyms, resultExprs, context); 461 } 462 463 // tensor::ExpandShapeOp is fusable with its consumer (i.e. 
reshape as a 464 // producer). Fusing when operand has higher rank will require use of mods and 465 // divs in the indexing maps of the fused op which would make it non-invertible. 466 static bool isTensorReshapeOpFoldableByLinearization( 467 tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { 468 if (!asProducer) 469 return false; 470 return useIndexMap.isPermutation(); 471 } 472 473 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a 474 // consumer). 475 static bool 476 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp, 477 AffineMap useIndexMap, 478 bool asProducer) { 479 if (asProducer) 480 return false; 481 return useIndexMap.isPermutation(); 482 } 483 484 /// Check if the reshape operation is only expansion into/collapsing of 485 /// unit-dimension. 486 template <typename TensorReshapeOp> 487 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) { 488 constexpr bool isExpanding = 489 std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value; 490 ArrayRef<int64_t> expandedShape = 491 (isExpanding ? reshapeOp.getResultType().getShape() 492 : reshapeOp.getSrcType().getShape()); 493 for (auto &indices : reshapeOp.getReassociationIndices()) { 494 unsigned numUnitDims = 0; 495 for (int64_t position : indices) 496 if (expandedShape[position] == 1) 497 numUnitDims++; 498 if (numUnitDims != indices.size() - 1) 499 return false; 500 } 501 return true; 502 } 503 504 namespace { 505 /// Pattern to fold tensor_expand_shape op with its consumer by using the source 506 /// of the reshape op as the operand in the consumer (instead of the result of 507 /// the tensor_collapse_shape). The corresponding index map in the consumer 508 /// needs to be modified to linearize the folded dimension. 
509 /// 510 /// For example, 511 /// 512 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 513 /// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] 514 /// tensor<?x?x?xf32> into tensor<?x?x4x?xf32> 515 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... } 516 /// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ... 517 /// -> tensor<?x?x4x?xf32> 518 /// 519 /// can be folded into 520 /// 521 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> 522 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 523 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... } 524 /// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ... 525 /// -> tensor<?x?x4x?xf32> 526 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp> 527 struct FoldProducerReshapeOpByLinearization 528 : public OpRewritePattern<GenericOp> { 529 using OpRewritePattern<GenericOp>::OpRewritePattern; 530 531 LogicalResult matchAndRewrite(GenericOp genericOp, 532 PatternRewriter &rewriter) const override { 533 if (!genericOp.hasTensorSemantics()) 534 return failure(); 535 SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands(); 536 for (const auto &en : llvm::enumerate(inputOperands)) { 537 auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>(); 538 if (!reshapeOp) 539 continue; 540 541 if (!isTensorReshapeOpFoldableByLinearization( 542 reshapeOp, genericOp.getTiedIndexingMap(en.value()), 543 /*asProducer =*/true) || 544 (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) 545 continue; 546 547 // Compute the fused operands list, 548 SmallVector<Value> fusedOperands = genericOp.getInputOperands(); 549 fusedOperands[en.index()] = reshapeOp.src(); 550 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 551 llvm::append_range(fusedOperands, outputOperands); 552 553 // Compute indexing_maps for the fused operation. 
The indexing_maps for 554 // the operands of the consumers that arent fused are the same. 555 SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps(); 556 557 // Compute the indexing map to use for the result of the producer. 558 AffineMap modifiedMap = 559 linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp); 560 // The modified map cannot have symbols. 561 if (modifiedMap.getNumSymbols()) 562 return failure(); 563 for (AffineExpr expr : modifiedMap.getResults()) { 564 if (!expr.isPureAffine()) 565 return failure(); 566 } 567 fusedIndexMaps[en.index()] = modifiedMap; 568 569 // Further check that the resulting index maps can be fused and 570 // inverted. Without this the resultant op is not legal. 571 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 572 return rewriter.notifyMatchFailure( 573 genericOp, "fused op loop bound computation failed"); 574 } 575 576 rewriter.startRootUpdate(genericOp); 577 genericOp->setOperands(fusedOperands); 578 genericOp.indexing_mapsAttr( 579 rewriter.getAffineMapArrayAttr(fusedIndexMaps)); 580 rewriter.finalizeRootUpdate(genericOp); 581 return success(); 582 } 583 return failure(); 584 } 585 }; 586 587 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its 588 /// producer. The corresponding index map in the consumer needs to be modified 589 /// to linearize the folded dimension. 
590 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp> 591 struct FoldConsumerReshapeOpByLinearization 592 : public OpRewritePattern<TensorReshapeOp> { 593 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 594 595 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 596 PatternRewriter &rewriter) const override { 597 GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>(); 598 if (!producer || !producer.hasTensorSemantics() || 599 producer.getNumOutputs() != 1 || 600 !isTensorReshapeOpFoldableByLinearization( 601 reshapeOp, 602 producer.getTiedIndexingMap(producer.getOutputOperand(0)), 603 /*asProducer =*/false) || 604 (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) 605 return failure(); 606 // The indexing_maps for the operands of the fused operation are same as 607 // those for the operands of the producer. 608 SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps(); 609 610 // Compute the indexing map to use for the operand of the producer. 611 AffineMap modifiedMap = linearizeCollapsedDims( 612 producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp); 613 for (AffineExpr expr : modifiedMap.getResults()) { 614 if (!expr.isPureAffine()) { 615 return rewriter.notifyMatchFailure( 616 producer, "fused op indexing map is not affine"); 617 } 618 } 619 fusedIndexMaps.back() = modifiedMap; 620 621 // Further check that the resulting index maps can be fused and 622 // inverted. Without this the resultant op is not legal. 
623 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 624 return rewriter.notifyMatchFailure( 625 producer, "fused op loop bound computation failed"); 626 } 627 628 Location loc = producer.getLoc(); 629 SmallVector<Value> inputOperands = producer.getInputOperands(); 630 Value output = rewriter.create<TensorReshapeOp>( 631 loc, producer.getOutputOperand(0)->get(), 632 reshapeOp.getReassociationExprs()); 633 auto fusedOp = rewriter.create<GenericOp>( 634 loc, reshapeOp.getResultType(), 635 /*inputs=*/inputOperands, 636 // TODO: handle outputs. 637 /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 638 producer.iterator_types(), 639 /*doc=*/nullptr, 640 /*library_call=*/nullptr); 641 auto &fusedRegion = fusedOp->getRegion(0); 642 rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion, 643 fusedRegion.begin()); 644 rewriter.replaceOp(reshapeOp, fusedOp->getResults()); 645 return success(); 646 } 647 }; 648 } // namespace 649 650 //===---------------------------------------------------------------------===// 651 // Methods and patterns that fuse reshape ops with elementwise operations by 652 // expanding the dimensionality of the elementwise operations. 653 //===---------------------------------------------------------------------===// 654 655 /// Conditions for folding a generic operation with a reshape op by expanding 656 /// the iteration space dimensionality for tensor operations. These are 657 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 658 /// following fusion pattern. 
659 /// 660 /// Consider 661 /// 662 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>) 663 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 664 /// affine_map<(d0, d1, d2) -> (d1, d2)>, 665 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] 666 /// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] 667 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 668 /// 669 /// The reshape can be folded into the `genericOp` if its loop dimensionality 670 /// is increased to match the result (operand) of the tensor_expand_shape. 671 /// The indexing_map of the fused tensor in the `genericOp` and the 672 /// reassociation map helps compute the indexing maps of the modified op. 673 /// For the above example, based on the reassociation map it 674 /// can be concluded that 675 /// 676 /// - The loop used to access the first dimension of the fused tensor is split 677 /// into two. 678 /// - The loop used to access the second dimension of the fused tensor is kept 679 /// as is. 680 /// - The loop used to access the third dimension of the fused tensor is split 681 /// into three. 682 /// 683 /// i.e. 
(e0, e1, e2, e3, e4) is the domain of the indexing map of the modified 684 /// op, then 685 /// 686 /// d0 -> e0, e1 687 /// d1 -> e2, e3, e4 688 /// d2 -> e5 689 /// 690 /// substituting this, the generic op can be rewritten as 691 /// 692 /// %d = linalg.generic ins(%0, %1 : ) 693 /// indexing_maps = 694 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>, 695 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>, 696 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>] 697 /// 698 /// Since operands to the linalg generic are now 5D, reshapes can be introduced 699 /// to make it consistent 700 /// 701 /// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]] 702 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 703 /// %1 = tensor.expand_shape %b [[0, 1, 2], [3]] 704 /// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32> 705 /// 706 /// The added reshapes are again expanding patterns, so they will get fused 707 /// with its producers if possible. 708 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, 709 OpOperand *fusableOpOperand) { 710 // Is fusable only if: 711 // - All the indexing maps for operands and results are projected 712 // permutations. 713 // - The fused tensor is not a scalar. 714 // - All the loops are parallel loops. 715 return genericOp.hasTensorSemantics() && 716 llvm::all_of(genericOp.indexing_maps().getValue(), 717 [](Attribute attr) { 718 return attr.cast<AffineMapAttr>() 719 .getValue() 720 .isProjectedPermutation(); 721 }) && 722 genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && 723 llvm::all_of(genericOp.iterator_types(), [](Attribute attr) { 724 return attr.cast<StringAttr>().getValue() == 725 getParallelIteratorTypeName(); 726 }); 727 } 728 729 namespace { 730 /// Information needed to expand a generic operation to fold the reshape with 731 /// it. 
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op. (`collapsedShape` is currently accepted for symmetry with
  // the caller but not consulted.)
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  /// Number of loop dimensions in the original (unexpanded) operation.
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  /// Number of loop dimensions in the expanded operation.
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  /// Expanded dimensions that original dimension `i` maps to.
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  /// Shape of the expanded dimensions that original dimension `i` maps to.
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  /// Loop extents of the original operation.
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  // Total number of loop dimensions in the expanded operation.
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  // All loop extents must be statically known to map them to the expanded
  // loop space.
  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
  originalLoopExtent.assign(originalLoopRange->begin(),
                            originalLoopRange->end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    // The reassociation group identifies the slice of `expandedShape` this
    // original dimension expands into.
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op:
  // original dimension i maps to the next `numExpandedDims[i]` consecutive
  // expanded dimensions.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the expanded
/// op. That requires the shape of the expanded dimensions to be static (at
/// least all but the most significant). For now check that these are all
/// statically sized. Note that this could be extended to handle dynamic case,
/// but the implementation below uses `affine.apply` which seems to have issues
/// when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  // Ops without index semantics need no index linearization; nothing to check.
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    // All but the outermost expanded dimension must be statically sized for
    // the linearization arithmetic below to be expressible.
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op for the given
/// `indexingMap` of the original operation.
843 static AffineMap 844 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 845 const ExpansionInfo &expansionInfo) { 846 SmallVector<AffineExpr> newExprs; 847 for (AffineExpr expr : indexingMap.getResults()) { 848 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 849 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 850 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 851 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 852 })); 853 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 854 } 855 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 856 indexingMap.getNumSymbols(), newExprs, 857 builder.getContext()); 858 } 859 860 /// Return the type of the operand/result to use in the expanded op given the 861 /// type in the original op. 862 static RankedTensorType getExpandedType(RankedTensorType originalType, 863 AffineMap indexingMap, 864 const ExpansionInfo &expansionInfo) { 865 SmallVector<int64_t> expandedShape; 866 for (AffineExpr expr : indexingMap.getResults()) { 867 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 868 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 869 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 870 } 871 return RankedTensorType::get(expandedShape, originalType.getElementType()); 872 } 873 874 /// Returns the reassociation maps to use in the `tensor.expand_shape` 875 /// operation to convert the operands of the original operation to operands of 876 /// the expanded operation. The same method is used to compute the 877 /// `tensor.collapse_shape` used to collapse the result of the expanded 878 /// op to get the value that can replace all uses of the results of the original 879 /// op. 
880 static SmallVector<ReassociationIndices> 881 getReassociationForExpansion(AffineMap indexingMap, 882 const ExpansionInfo &expansionInfo) { 883 SmallVector<ReassociationIndices> reassociation; 884 unsigned numReshapeDims = 0; 885 for (AffineExpr expr : indexingMap.getResults()) { 886 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 887 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 888 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 889 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 890 reassociation.emplace_back(std::move(indices)); 891 numReshapeDims += numExpandedDims; 892 } 893 return reassociation; 894 } 895 896 /// Update the body of an expanded linalg operation having index semantics. The 897 /// indices of the original operation need to be recovered by linearizing the 898 /// indices of the correspoding dimensions of the expanded operation. For now it 899 /// is assumed that the shapes of the expanded operation needed for 900 /// linearization are static. 901 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, 902 Location loc, Region &fusedRegion, 903 const ExpansionInfo &expansionInfo) { 904 // Replace the original indices by the linearization of the expanded indices. 905 for (IndexOp indexOp : 906 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) { 907 ArrayRef<int64_t> expandedDims = 908 expansionInfo.getExpandedDims(indexOp.dim()); 909 assert(!expandedDims.empty() && "expected valid expansion info"); 910 911 // Skip index operations that are not affected by the expansion. 912 if (expandedDims.size() == 1 && 913 expandedDims.front() == (int64_t)indexOp.dim()) 914 continue; 915 916 // Linearize the expanded indices of the original index dimension. 
917 OpBuilder::InsertionGuard guard(rewriter); 918 rewriter.setInsertionPointAfter(indexOp); 919 ArrayRef<int64_t> expandedDimsShape = 920 expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front(); 921 SmallVector<Value> expandedIndices; 922 expandedIndices.reserve(expandedDims.size() - 1); 923 llvm::transform( 924 expandedDims.drop_front(), std::back_inserter(expandedIndices), 925 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); }); 926 Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front()); 927 for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) { 928 assert(!ShapedType::isDynamic(std::get<0>(it))); 929 AffineExpr idx, acc; 930 bindDims(rewriter.getContext(), idx, acc); 931 newIndex = rewriter.create<AffineApplyOp>( 932 indexOp.getLoc(), idx + acc * std::get<0>(it), 933 ValueRange{std::get<1>(it), newIndex}); 934 } 935 rewriter.replaceOp(indexOp, newIndex); 936 } 937 } 938 939 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op 940 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes 941 /// that those conditions have been satisfied. 942 static Optional<SmallVector<Value>> 943 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, 944 OpOperand *fusableOpOperand, 945 PatternRewriter &rewriter) { 946 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && 947 "preconditions for fuse operation failed"); 948 // Check if reshape is expanding or collapsing. 949 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp); 950 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp); 951 bool isExpanding = (expandingReshapeOp != nullptr); 952 RankedTensorType expandedType = isExpanding 953 ? expandingReshapeOp.getResultType() 954 : collapsingReshapeOp.getSrcType(); 955 RankedTensorType collapsedType = isExpanding 956 ? 
expandingReshapeOp.getSrcType() 957 : collapsingReshapeOp.getResultType(); 958 959 ExpansionInfo expansionInfo; 960 if (failed(expansionInfo.compute( 961 genericOp, fusableOpOperand, 962 isExpanding ? expandingReshapeOp.getReassociationMaps() 963 : collapsingReshapeOp.getReassociationMaps(), 964 expandedType.getShape(), collapsedType.getShape(), rewriter))) 965 return llvm::None; 966 967 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) 968 return llvm::None; 969 970 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>( 971 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) { 972 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); 973 })); 974 975 SmallVector<Value> expandedOpOperands; 976 expandedOpOperands.reserve(genericOp.getNumInputs()); 977 for (OpOperand *opOperand : genericOp.getInputOperands()) { 978 if (opOperand == fusableOpOperand) { 979 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() 980 : collapsingReshapeOp.src()); 981 continue; 982 } 983 if (genericOp.isInputTensor(opOperand)) { 984 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 985 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 986 RankedTensorType expandedOperandType = 987 getExpandedType(opOperandType, indexingMap, expansionInfo); 988 if (expandedOperandType != opOperand->get().getType()) { 989 // Reshape the operand to get the right type. 
990 SmallVector<ReassociationIndices> reassociation = 991 getReassociationForExpansion(indexingMap, expansionInfo); 992 if (failed(reshapeLikeShapesAreCompatible( 993 [&](const Twine &msg) { 994 return rewriter.notifyMatchFailure(genericOp, msg); 995 }, 996 opOperandType.getShape(), expandedOperandType.getShape(), 997 reassociation, 998 /*isExpandingReshape=*/true))) 999 return llvm::None; 1000 expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>( 1001 genericOp.getLoc(), expandedOperandType, opOperand->get(), 1002 reassociation)); 1003 continue; 1004 } 1005 } 1006 expandedOpOperands.push_back(opOperand->get()); 1007 } 1008 1009 Location loc = genericOp.getLoc(); 1010 SmallVector<Value> outputs; 1011 for (OpOperand *opOperand : genericOp.getOutputOperands()) { 1012 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 1013 auto opOperandType = opOperand->get().getType().cast<RankedTensorType>(); 1014 RankedTensorType expandedOutputType = 1015 getExpandedType(opOperandType, indexingMap, expansionInfo); 1016 if (expandedOutputType != opOperand->get().getType()) { 1017 SmallVector<ReassociationIndices> reassociation = 1018 getReassociationForExpansion(indexingMap, expansionInfo); 1019 if (failed(reshapeLikeShapesAreCompatible( 1020 [&](const Twine &msg) { 1021 return rewriter.notifyMatchFailure(genericOp, msg); 1022 }, 1023 opOperandType.getShape(), expandedOutputType.getShape(), 1024 reassociation, 1025 /*isExpandingReshape=*/true))) 1026 return llvm::None; 1027 outputs.push_back(rewriter.create<tensor::ExpandShapeOp>( 1028 genericOp.getLoc(), expandedOutputType, opOperand->get(), 1029 reassociation)); 1030 } 1031 } 1032 1033 // The iterator types of the expanded op are all parallel. 
1034 SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(), 1035 getParallelIteratorTypeName()); 1036 1037 TypeRange resultTypes = ValueRange(outputs).getTypes(); 1038 auto fusedOp = 1039 rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes, 1040 /*inputs=*/expandedOpOperands, outputs, 1041 expandedOpIndexingMaps, iteratorTypes); 1042 Region &fusedRegion = fusedOp->getRegion(0); 1043 Region &originalRegion = genericOp->getRegion(0); 1044 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin()); 1045 1046 // Update the index accesses after the expansion. 1047 updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo); 1048 1049 // Reshape the result values to their original shape if this is a collapsing 1050 // reshape folded into its consumer. 1051 SmallVector<Value> resultVals; 1052 for (OpResult opResult : genericOp->getOpResults()) { 1053 int64_t resultNumber = opResult.getResultNumber(); 1054 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) { 1055 SmallVector<ReassociationIndices> reassociation = 1056 getReassociationForExpansion( 1057 genericOp.getTiedIndexingMap( 1058 genericOp.getOutputOperand(resultNumber)), 1059 expansionInfo); 1060 resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>( 1061 genericOp.getLoc(), opResult.getType(), 1062 fusedOp->getResult(resultNumber), reassociation)); 1063 } else { 1064 resultVals.push_back(fusedOp->getResult(resultNumber)); 1065 } 1066 } 1067 // Assuming a single result. 1068 return resultVals; 1069 } 1070 1071 namespace { 1072 1073 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op, 1074 /// when the reshape op is collapsing dimensions. The dimensionality of the loop 1075 /// in the consumer is expanded. 
1076 class FoldWithProducerReshapeOpByExpansion 1077 : public OpRewritePattern<GenericOp> { 1078 public: 1079 FoldWithProducerReshapeOpByExpansion( 1080 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1081 PatternBenefit benefit = 1) 1082 : OpRewritePattern<GenericOp>(context, benefit), 1083 controlFoldingReshapes(std::move(foldReshapes)) {} 1084 1085 LogicalResult matchAndRewrite(GenericOp genericOp, 1086 PatternRewriter &rewriter) const override { 1087 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 1088 tensor::CollapseShapeOp reshapeOp = 1089 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>(); 1090 if (!reshapeOp) 1091 continue; 1092 // Fold only if 1093 // - The tensor reshape op is folding. 1094 // - All constraints of fusing with reshape by expansion are met. 1095 if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) || 1096 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand))) 1097 continue; 1098 1099 Optional<SmallVector<Value>> replacementValues = 1100 fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); 1101 if (!replacementValues) 1102 return failure(); 1103 rewriter.replaceOp(genericOp, replacementValues.getValue()); 1104 return success(); 1105 } 1106 return failure(); 1107 } 1108 1109 private: 1110 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1111 }; 1112 1113 /// Pattern to fold a tensor_expand_shape op with its producer generic op 1114 /// by expanding the dimensionality of the loop in the producer op. 
1115 struct FoldReshapeWithGenericOpByExpansion 1116 : public OpRewritePattern<tensor::ExpandShapeOp> { 1117 1118 FoldReshapeWithGenericOpByExpansion( 1119 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1120 PatternBenefit benefit = 1) 1121 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), 1122 controlFoldingReshapes(std::move(foldReshapes)) {} 1123 1124 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, 1125 PatternRewriter &rewriter) const override { 1126 // Fold only if all constraints of fusing with reshape by expansion are met. 1127 GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>(); 1128 if (!producer || producer.getNumOutputs() != 1 || 1129 !isFusableWithReshapeByDimExpansion(producer, 1130 producer.getOutputOperand(0)) || 1131 !controlFoldingReshapes(producer->getResult(0), 1132 reshapeOp->getOpOperand(0))) 1133 return failure(); 1134 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 1135 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 1136 if (!replacementValues) 1137 return failure(); 1138 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 1139 return success(); 1140 } 1141 1142 private: 1143 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1144 }; 1145 } // namespace 1146 1147 //===---------------------------------------------------------------------===// 1148 // Methods and patterns to convert tensor.expand_shape -> linalg.generic 1149 // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down. 
//===---------------------------------------------------------------------===//

/// Return, for each affine map in `maps`, the list of dimension positions it
/// references, i.e. a reassociation group per map.
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

namespace {
/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
/// within the dimensions we need to merge.
///
/// For example,
///
///  %0 = tensor.expand_shape %A [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///    ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
///  into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensor.
    if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor.expand_shape operands and save the dimensions that
    // get merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps as the first one.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the associated reverse map
    // (expanded dimension -> reassociation group).
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg and that we
    // don't need to create new reshapes operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (const auto &en : llvm::enumerate(inputOperands)) {
      // Operands that did not come through the found reshape must not touch
      // any merged dimension group of size > 1.
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip dimensions merged except for the last of the group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that their types match the uses.
    SmallVector<Value> newResults;
    for (const auto &result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
1317 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> { 1318 public: 1319 FoldScalarOrSplatConstant(MLIRContext *context, 1320 ControlElementwiseOpsFusionFn &fun, 1321 PatternBenefit benefit = 1) 1322 : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {} 1323 1324 LogicalResult matchAndRewrite(GenericOp genericOp, 1325 PatternRewriter &rewriter) const override { 1326 if (!genericOp.hasTensorSemantics()) 1327 return failure(); 1328 for (OpOperand *opOperand : genericOp.getInputOperands()) { 1329 Operation *def = opOperand->get().getDefiningOp(); 1330 Attribute constantAttr; 1331 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool { 1332 { 1333 DenseElementsAttr splatAttr; 1334 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) && 1335 splatAttr.isSplat() && 1336 splatAttr.getType().getElementType().isIntOrFloat()) { 1337 constantAttr = splatAttr.getSplatValue<Attribute>(); 1338 return true; 1339 } 1340 } 1341 { 1342 IntegerAttr intAttr; 1343 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) { 1344 constantAttr = intAttr; 1345 return true; 1346 } 1347 } 1348 { 1349 FloatAttr floatAttr; 1350 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) { 1351 constantAttr = floatAttr; 1352 return true; 1353 } 1354 } 1355 return false; 1356 }; 1357 1358 auto resultValue = opOperand->get().dyn_cast<OpResult>(); 1359 if (!def || !resultValue || !isScalarOrSplatConstantOp(def) || 1360 !controlFn(resultValue, *opOperand)) 1361 continue; 1362 1363 // The operands and the indexing_maps of the fused operation the same as 1364 // the operands and indexing_maps of the generic operations with the 1365 // values at the constant index dropped. 
1366 SmallVector<AffineMap> fusedIndexMaps; 1367 SmallVector<Value> fusedOperands; 1368 SmallVector<Location> fusedLocs{genericOp.getLoc()}; 1369 fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); 1370 fusedOperands.reserve(genericOp.getNumInputs()); 1371 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); 1372 for (OpOperand *inputOperand : genericOp.getInputOperands()) { 1373 if (inputOperand == opOperand) 1374 continue; 1375 Value inputValue = inputOperand->get(); 1376 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand)); 1377 fusedOperands.push_back(inputValue); 1378 fusedLocs.push_back(inputValue.getLoc()); 1379 } 1380 for (OpOperand *outputOperand : genericOp.getOutputOperands()) 1381 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand)); 1382 1383 // Check if the operation shapes to loops map is computable. 1384 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 1385 return rewriter.notifyMatchFailure( 1386 genericOp, "fused op loop bound computation failed"); 1387 } 1388 1389 // Create a constant scalar value from the splat constant. 1390 Value scalarConstant = rewriter.create<arith::ConstantOp>( 1391 def->getLoc(), constantAttr, constantAttr.getType()); 1392 1393 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 1394 auto fusedOp = rewriter.create<GenericOp>( 1395 rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), 1396 /*inputs=*/fusedOperands, 1397 /*outputs=*/outputOperands, 1398 rewriter.getAffineMapArrayAttr(fusedIndexMaps), 1399 genericOp.iterator_types(), 1400 /*doc=*/nullptr, 1401 /*library_call=*/nullptr); 1402 1403 // Map the block argument corresponding to the replaced argument with the 1404 // scalar constant. 
1405 Region ®ion = genericOp->getRegion(0); 1406 Block &entryBlock = *region.begin(); 1407 BlockAndValueMapping mapping; 1408 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()), 1409 scalarConstant); 1410 Region &fusedRegion = fusedOp->getRegion(0); 1411 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(), 1412 mapping); 1413 rewriter.replaceOp(genericOp, fusedOp->getResults()); 1414 return success(); 1415 } 1416 return failure(); 1417 } 1418 1419 private: 1420 ControlElementwiseOpsFusionFn controlFn; 1421 }; 1422 1423 /// Base class for constant folding linalg.generic ops with N inputs, 1 output, 1424 /// and permutation indexing maps. 1425 /// 1426 /// `ConcreteType` should provide methods with signatures 1427 /// 1428 /// ```c++ 1429 /// bool matchIndexingMaps(GenericOp genericOp) const; 1430 /// RegionComputationFn getRegionComputeFn(GenericOp) const; 1431 /// ``` 1432 /// 1433 /// The latter inspects the region and returns the computation inside as a 1434 /// functor. The functor will be invoked with constant elements for all inputs 1435 /// and should return the corresponding computea constant element for output. 
1436 template <typename ConcreteType> 1437 class FoldConstantBase : public OpRewritePattern<GenericOp> { 1438 public: 1439 struct APIntOrFloat { 1440 Optional<APInt> apInt; 1441 Optional<APFloat> apFloat; 1442 }; 1443 struct APIntOrFloatArray { 1444 SmallVector<APInt> apInts; 1445 SmallVector<APFloat> apFloats; 1446 }; 1447 using RegionComputationFn = 1448 std::function<APIntOrFloat(const APIntOrFloatArray &)>; 1449 1450 FoldConstantBase(MLIRContext *context, 1451 const ControlElementwiseOpsFusionFn &controlFn, 1452 PatternBenefit benefit = 1) 1453 : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {} 1454 1455 LogicalResult matchAndRewrite(GenericOp genericOp, 1456 PatternRewriter &rewriter) const override { 1457 if (genericOp.hasBufferSemantics()) 1458 return failure(); 1459 1460 // Only support ops generating one output for now. 1461 if (genericOp.getNumOutputs() != 1) 1462 return failure(); 1463 1464 auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>(); 1465 // Require the output types to be static give we are generating constants. 1466 if (!outputType || !outputType.hasStaticShape()) 1467 return failure(); 1468 1469 if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { 1470 return operand->get().getType().isa<ShapedType>(); 1471 })) 1472 return failure(); 1473 1474 // Make sure all element types are the same. 1475 auto getOperandElementType = [](OpOperand *operand) { 1476 return operand->get().getType().cast<ShapedType>().getElementType(); 1477 }; 1478 if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(), 1479 getOperandElementType))) 1480 return failure(); 1481 1482 // We can only handle the case where we have int/float elements. 1483 auto elementType = outputType.getElementType(); 1484 if (!elementType.isIntOrFloat()) 1485 return failure(); 1486 1487 // Require all indexing maps to be permutations for now. 
This is common and 1488 // it simplifies input/output access greatly: we can do the data shuffling 1489 // entirely in the compiler, without needing to turn all indices into 1490 // Values, and then do affine apply on them, and then match back the 1491 // constant again. 1492 if (!llvm::all_of(genericOp.getIndexingMaps(), 1493 [](AffineMap map) { return map.isPermutation(); })) 1494 return failure(); 1495 1496 for (OpOperand *operand : genericOp.getOutputOperands()) { 1497 if (genericOp.payloadUsesValueFromOperand(operand)) 1498 return failure(); 1499 } 1500 1501 // Further check the indexing maps are okay for the ConcreteType. 1502 if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp)) 1503 return failure(); 1504 1505 // Defer to the concrete type to check the region and discover the 1506 // computation inside. 1507 RegionComputationFn computeFn = 1508 static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp); 1509 if (!computeFn) 1510 return failure(); 1511 1512 // All inputs should be constants. 1513 int numInputs = genericOp.getNumInputs(); 1514 SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs); 1515 for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) { 1516 if (!matchPattern(operand.value()->get(), 1517 m_Constant(&inputValues[operand.index()]))) 1518 return failure(); 1519 } 1520 1521 // Identified this as a potential candidate for folding. Now check the 1522 // policy to see whether we are allowed to proceed. 
1523 for (int i = 0; i < numInputs; ++i) { 1524 OpOperand *consumer = genericOp.getInputOperand(i); 1525 OpResult producer = consumer->get().cast<OpResult>(); 1526 if (!controlFn(producer, *consumer)) 1527 return failure(); 1528 } 1529 1530 auto linalgOp = cast<LinalgOp>(genericOp.getOperation()); 1531 SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes(); 1532 int64_t numElements = outputType.getNumElements(); 1533 1534 // Use APInt/APFloat instead of Attribute here for constructing the output. 1535 // This helps to avoid blowing up compiler memory usage: Attributes would 1536 // unify the following cases but they have lifetime as the MLIRContext. 1537 SmallVector<APInt> intOutputValues; 1538 SmallVector<APFloat> fpOutputValues; 1539 if (elementType.template isa<FloatType>()) 1540 fpOutputValues.resize(numElements, APFloat(0.f)); 1541 else 1542 intOutputValues.resize(numElements); 1543 1544 // Return the constant dim positions from the given permutation map. 1545 auto getDimPositions = [](AffineMap map) { 1546 SmallVector<unsigned> dims; 1547 dims.reserve(map.getNumResults()); 1548 for (AffineExpr result : map.getResults()) { 1549 dims.push_back(result.cast<AffineDimExpr>().getPosition()); 1550 } 1551 return dims; 1552 }; 1553 1554 SmallVector<SmallVector<unsigned>> inputDims; 1555 for (int i = 0; i < numInputs; ++i) 1556 inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i])); 1557 auto outputDims = getDimPositions(genericOp.getIndexingMaps().back()); 1558 auto outputShape = outputType.getShape(); 1559 1560 // Allocate small vectors for index delinearization. Initial values do not 1561 // matter here as they will be overwritten later. 
1562 SmallVector<uint64_t> indices(loopBounds.size(), 0); 1563 SmallVector<uint64_t> dstIndices(loopBounds.size(), 0); 1564 SmallVector<SmallVector<uint64_t>> srcIndices( 1565 numInputs, SmallVector<uint64_t>(loopBounds.size(), 0)); 1566 SmallVector<uint64_t> srcLinearIndices(numInputs, 0); 1567 uint64_t dstLinearIndex = 0; 1568 1569 // Allocate spaces for compute function inputs. Initial values do not matter 1570 // here as they will be overwritten later. 1571 APIntOrFloatArray computeFnInputs; 1572 1573 auto inputShapes = llvm::to_vector<4>( 1574 llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { 1575 return operand->get().getType().cast<ShapedType>().getShape(); 1576 })); 1577 1578 // Given a `linearIndex`, remap it to a linear index to access linalg op 1579 // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`, 1580 // `srcLinearIndices`, `dstLinearIndex` in place. 1581 auto computeRemappedLinearIndex = [&](int linearIndex) { 1582 int totalCount = linearIndex; 1583 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { 1584 indices[dim] = totalCount % loopBounds[dim]; 1585 totalCount /= loopBounds[dim]; 1586 } 1587 1588 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { 1589 for (int i = 0; i < numInputs; ++i) 1590 srcIndices[i][dim] = indices[inputDims[i][dim]]; 1591 dstIndices[dim] = indices[outputDims[dim]]; 1592 } 1593 1594 dstLinearIndex = dstIndices.front(); 1595 for (int i = 0; i < numInputs; ++i) 1596 srcLinearIndices[i] = srcIndices[i].front(); 1597 1598 for (int dim = 1; dim < outputType.getRank(); ++dim) { 1599 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; 1600 for (int i = 0; i < numInputs; ++i) 1601 srcLinearIndices[i] = 1602 srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; 1603 } 1604 }; 1605 1606 bool isFloat = elementType.isa<FloatType>(); 1607 if (isFloat) { 1608 SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges; 1609 for (int i = 0; i < 
numInputs; ++i) 1610 inFpRanges.push_back(inputValues[i].getValues<APFloat>()); 1611 1612 computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); 1613 1614 // Transpose the input constant. Because we don't know its rank in 1615 // advance, we need to loop over the range [0, element count) and 1616 // delinearize the index. 1617 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { 1618 computeRemappedLinearIndex(linearIndex); 1619 1620 // Collect constant elements for all inputs at this loop iteration. 1621 for (int i = 0; i < numInputs; ++i) 1622 computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; 1623 1624 // Invoke the computation to get the corresponding constant output 1625 // element. 1626 fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; 1627 } 1628 } else { 1629 SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges; 1630 for (int i = 0; i < numInputs; ++i) 1631 inIntRanges.push_back(inputValues[i].getValues<APInt>()); 1632 1633 computeFnInputs.apInts.resize(numInputs); 1634 1635 // Transpose the input constant. Because we don't know its rank in 1636 // advance, we need to loop over the range [0, element count) and 1637 // delinearize the index. 1638 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { 1639 computeRemappedLinearIndex(linearIndex); 1640 1641 // Collect constant elements for all inputs at this loop iteration. 1642 for (int i = 0; i < numInputs; ++i) 1643 computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; 1644 1645 // Invoke the computation to get the corresponding constant output 1646 // element. 1647 intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; 1648 } 1649 } 1650 1651 DenseElementsAttr outputAttr = 1652 isFloat ? 
DenseElementsAttr::get(outputType, fpOutputValues) 1653 : DenseElementsAttr::get(outputType, intOutputValues); 1654 1655 rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr); 1656 return success(); 1657 } 1658 1659 private: 1660 ControlElementwiseOpsFusionFn controlFn; 1661 }; 1662 1663 // Folds linalg.generic ops that are actually transposes on constant values. 1664 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> { 1665 using FoldConstantBase::FoldConstantBase; 1666 1667 bool matchIndexingMaps(GenericOp genericOp) const { 1668 // We should have one input and one output. 1669 return genericOp.getIndexingMaps().size() == 2; 1670 } 1671 1672 RegionComputationFn getRegionComputeFn(GenericOp genericOp) const { 1673 // Make sure the region only contains a yield op. 1674 Block &body = genericOp.region().front(); 1675 if (!llvm::hasSingleElement(body)) 1676 return nullptr; 1677 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); 1678 if (!yieldOp) 1679 return nullptr; 1680 1681 // The yield op should return the block argument corresponds to the input. 1682 for (Value yieldVal : yieldOp.values()) { 1683 auto yieldArg = yieldVal.dyn_cast<BlockArgument>(); 1684 if (!yieldArg || yieldArg.getOwner() != &body) 1685 return nullptr; 1686 if (yieldArg.getArgNumber() != 0) 1687 return nullptr; 1688 } 1689 1690 // No computation; just return the orginal value. 1691 return [](const APIntOrFloatArray &inputs) { 1692 if (inputs.apFloats.empty()) 1693 return APIntOrFloat{inputs.apInts.front(), llvm::None}; 1694 return APIntOrFloat{llvm::None, inputs.apFloats.front()}; 1695 }; 1696 } 1697 1698 ControlElementwiseOpsFusionFn controlFn; 1699 }; 1700 1701 } // namespace 1702 1703 //===---------------------------------------------------------------------===// 1704 // Miscellaneous patterns that help fusion. 
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // The op is modified in place, so the rewriter must be notified up front;
    // the update is cancelled below if no operand was actually replaced.
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        // Only ranked tensor outputs can be replaced with an init_tensor of
        // the same shape.
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        // Materialize the dynamic dimension sizes of the original operand so
        // the replacement init_tensor has an identical shape.
        SmallVector<Value> dynamicDims;
        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

/// Populates patterns that fold producer/consumer reshape ops into generic ops
/// by linearization. The `false` template argument presumably selects the
/// variant that is not restricted to unit-dimension reshapes (see the
/// unit-dims overload below) — the template definition is outside this file.
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
          patterns.getContext());
}

/// Same pattern set as above, instantiated with `true`; used by the
/// linearization test pass when unit-dim reshape folding is allowed.
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
          patterns.getContext());
}
1778 1779 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( 1780 RewritePatternSet &patterns, 1781 const ControlElementwiseOpsFusionFn &controlFoldingReshapes) { 1782 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(), 1783 controlFoldingReshapes); 1784 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), 1785 controlFoldingReshapes); 1786 } 1787 1788 void mlir::linalg::populateElementwiseOpsFusionPatterns( 1789 RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) { 1790 auto *context = patterns.getContext(); 1791 patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant, 1792 FoldConstantTranspose>(context, 1793 options.controlElementwiseOpsFusionFn); 1794 patterns.add<RemoveOutsDependency>(context); 1795 populateFoldReshapeOpsByExpansionPatterns(patterns, 1796 options.controlFoldingReshapesFn); 1797 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 1798 GenericOp::getCanonicalizationPatterns(patterns, context); 1799 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); 1800 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); 1801 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns( 1802 patterns); 1803 } 1804 1805 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) { 1806 auto *context = patterns.getContext(); 1807 patterns.add<PushExpandingReshape>(context); 1808 } 1809 1810 //===---------------------------------------------------------------------===// 1811 // Passes 1812 //===---------------------------------------------------------------------===// 1813 1814 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, 1815 OpOperand &consumer) { 1816 if (auto producerCollapseOp = 1817 dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) { 1818 return !isUnitDimExpansionOnly(producerCollapseOp); 1819 } 1820 if (auto consumerExpandOp = 1821 
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { 1822 return !isUnitDimExpansionOnly(consumerExpandOp); 1823 } 1824 return true; 1825 } 1826 1827 namespace { 1828 1829 /// Pass that fuses generic ops on tensors. Used only for testing. 1830 struct LinalgElementwiseOpFusionPass 1831 : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> { 1832 void runOnOperation() override { 1833 Operation *op = getOperation(); 1834 RewritePatternSet patterns(op->getContext()); 1835 ControlElementwiseOpsFusionFn allowFoldingFn = 1836 [](const OpResult &producer, const OpOperand &consumer) { 1837 return true; 1838 }; 1839 populateElementwiseOpsFusionPatterns( 1840 patterns, 1841 LinalgElementwiseFusionOptions().setControlFoldingReshapes( 1842 allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape)); 1843 1844 // Use TopDownTraversal for compile time reasons 1845 GreedyRewriteConfig grc; 1846 grc.useTopDownTraversal = true; 1847 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), 1848 grc); 1849 } 1850 }; 1851 1852 /// Pass to test folding of reshape ops with generic ops by linearization. 
1853 struct FoldReshapeOpsByLinearizationPass 1854 : public LinalgFoldReshapeOpsByLinearizationBase< 1855 FoldReshapeOpsByLinearizationPass> { 1856 void runOnOperation() override { 1857 Operation *op = getOperation(); 1858 RewritePatternSet patterns(op->getContext()); 1859 populateFoldReshapeOpsByLinearizationPatterns(patterns); 1860 if (allowFoldingUnitDimReshapes) { 1861 populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns); 1862 } 1863 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); 1864 } 1865 }; 1866 1867 } // namespace 1868 1869 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() { 1870 return std::make_unique<LinalgElementwiseOpFusionPass>(); 1871 } 1872 1873 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() { 1874 return std::make_unique<FoldReshapeOpsByLinearizationPass>(); 1875 } 1876