//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

/// Compute the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the producer's result in the
/// consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isInputTensor(consumerOpOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Currently support only operations with a single result.
  if (producer.getNumOutputs() != 1)
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    llvm::BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
      Value operand = std::get<0>(pair);
      if (operand == consumerOpOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
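    // For illustration only (hypothetical maps): if `consumerToProducerLoopsMap`
    // is (d0, d1) -> (d1, d0), then a `linalg.index 0` in the producer block is
    // rewritten below as an affine.apply of the sub-map (d0, d1) -> (d1) on the
    // fused indices, i.e. it now reads the second fused loop index.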
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity check: if the replacement is not already in the mapper then it must
  // be produced outside.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlElementwiseOpsFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
      !controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(producer->getNumOperands() +
                        consumer->getNumOperands());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator it =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  for (OpOperand *opOperand : producer.getInputOperands()) {
    fusedOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return llvm::None;
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}

/// Linearize the expressions in `sourceMap` based on the reassociation indices
/// of `reshapeOp`, given the shape of the source tensor that corresponds to
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
///   %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map
///     `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  // The new affine map cannot drop unused dimensions but some new symbols may
  // have been added. Create a map with at least as many dimensions/symbols as
  // the original affine map.
  int64_t maxDim = -1;
  int64_t maxSym = -1;
  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
  return AffineMap::get(numDims, numSyms, resultExprs, context);
}

// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
// producer). Fusing when the operand has higher rank would require the use of
// mods and divs in the indexing maps of the fused op, which would make it
// non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
    tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer)
    return false;
  return useIndexMap.isPermutation();
}

// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
// consumer).
static bool
isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
                                         AffineMap useIndexMap,
                                         bool asProducer) {
  if (asProducer)
    return false;
  return useIndexMap.isPermutation();
}

/// Check if the reshape operation is only expanding into/collapsing of
/// unit-dimensions.
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t position : indices)
      if (expandedShape[position] == 1)
        numUnitDims++;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `fuseWithReshapeByExpansion` which implements the
/// following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `genericOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor_expand_shape.
///  The indexing_map of the fused tensor in the `genericOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the generic op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the modified op is now 6D, reshapes can be
///  introduced on the operands to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
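  // For example (hypothetical maps): affine_map<(d0, d1, d2) -> (d0, d2)> is a
  // projected permutation, while affine_map<(d0, d1) -> (d0 + d1)> is not.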
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
  originalLoopExtent.assign(originalLoopRange->begin(),
                            originalLoopRange->end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
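  // For example (hypothetical): with fusedIndexMap = (d0, d1, d2) -> (d1, d0),
  // reassociation maps {[0, 1], [2]} and expandedShape [4, 8, 16], d1 expands
  // to [4, 8] and d0 to [16], while d2 below keeps its original extent. The
  // resulting reassociation is then d0 -> [0], d1 -> [1, 2], d2 -> [3].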
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                    const ExpansionInfo &expansionInfo,
                                    PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
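/// For example (hypothetical expansion info): with indexing map
/// (d0, d1) -> (d1, d0), d0 expanded to shape [2, 3] and d1 to [4], an
/// original tensor<4x6xf32> operand becomes tensor<4x2x3xf32> in the expanded
/// op.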
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded op to
/// get the value that can replace all uses of the results of the original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
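    // For example (hypothetical sizes): an original dimension expanded into
    // dims (e0, e1, e2) with static sizes (s0, s1, s2) is recovered as
    // (e0 * s1 + e1) * s2 + e2, built up one affine.apply at a time below.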
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return llvm::None;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumInputs());
  for (OpOperand *opOperand : genericOp.getInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                               : collapsingReshapeOp.src());
      continue;
    }
    if (genericOp.isInputTensor(opOperand)) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(genericOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return llvm::None;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
    auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(genericOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return llvm::None;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    } else {
      // The output is not affected by the expansion; use it as is.
      outputs.push_back(opOperand->get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fold a tensor_expand_shape op with its consumer by using the
/// source of the reshape op as the operand in the consumer (instead of the
/// result of the reshape). The corresponding index map in the consumer
/// needs to be modified to linearize the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
///      : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer=*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
///
/// For example,
///
///  %0 = tensor.expand_shape %A [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///      ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
///  into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensors.
    if (!genericOp.hasTensorSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and record the dimensions that
    // get merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support a non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the associated reverse map.
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg op and that we
    // don't need to create new reshape operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (const auto &en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip merged dimensions except for the last one of each group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that their types match the uses.
    SmallVector<Value> newResults;
    for (const auto &result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop nest in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is collapsing dimensions.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer=*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    auto invMap = inversePermutation(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};

/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
    if (!producer || producer.getNumOutputs() != 1 ||
        !isFusableWithReshapeByDimExpansion(producer,
                                            producer.getOutputOperand(0)) ||
        !controlFoldingReshapes(producer->getResult(0),
                                reshapeOp->getOpOperand(0)))
      return failure();
    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
    if (!replacementValues)
      return failure();
    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a generic op with a splat constant/scalar constant. Does
/// not handle cases where the constant is not single-valued.
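/// For example (hypothetical IR): an input fed by
/// `%cst = arith.constant dense<42.0> : tensor<4x8xf32>` is dropped from the
/// generic op, and the corresponding block argument is replaced by the scalar
/// `arith.constant 42.0 : f32`.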
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation, with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the
/// output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given that we are generating
    // constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, apply affine maps to them, and then match the constant back
    // again.
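    // For example, a 2-D transpose uses the permutation maps
    // (d0, d1) -> (d1, d0) and (d0, d1) -> (d0, d1), whereas a broadcasting
    // map such as (d0, d1) -> (d0) is not a permutation and is rejected here.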
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases, but they live as long as the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

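    // Worked example of the remapping done by `computeRemappedLinearIndex`
    // below (hypothetical 2x3 transpose: input map (d0, d1) -> (d1, d0),
    // output map (d0, d1) -> (d0, d1), loopBounds = [2, 3], input shape 3x2):
    // linearIndex 1 delinearizes to indices = [0, 1], so output position
    // (0, 1) (dstLinearIndex == 1) is read from input position (1, 0)
    // (srcLinearIndices[0] == 2).
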
    // Allocate space for the compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, and `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

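    // Wrap the computed elements into a dense attribute of the static output
    // type and replace the generic op with a constant holding that attribute.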
    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the
    // input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), llvm::None};
      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
    };
  }

  ControlElementwiseOpsFusionFn controlFn;
};

} // namespace

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on
    // tensors.
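    // Each candidate producer is tried in turn; the first fusion that
    // succeeds replaces this generic op with the fused op's results.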
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    ControlElementwiseOpsFusionFn allowFoldingFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          return true;
        };
    populateElementwiseOpsFusionPatterns(
        patterns,
        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));

    // Use TopDownTraversal for compile-time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
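        // Otherwise replace it with a fresh `linalg.init_tensor` of the same
        // shape, so the op no longer depends on the original tensor value.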
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

} // namespace

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
               FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
      patterns);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}
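
// Usage sketch (hypothetical client code; how the pass is nested depends on
// the surrounding pipeline):
//
//   PassManager pm(module.getContext());
//   pm.addNestedPass<FuncOp>(mlir::createLinalgElementwiseOpFusionPass());
//   if (failed(pm.run(module)))
//     ... // handle the failure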