1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===/// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion on tensors operations pass. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "PassDetail.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Support/LLVM.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of 30 /// the `producer` to use in the fused operation given the indexing map of the 31 /// result of the producer in the consumer. 32 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 33 OpOperand *producerOpOperand, AffineMap producerResultIndexMap, 34 AffineMap fusedConsumerArgIndexMap) { 35 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 36 // from consumer loop -> consumer arg tensor index/producer result tensor 37 // index. The fused loop is same as the consumer loop. For each producer arg 38 // the indexing map to be computed is a map from consumer loop -> producer 39 // arg tensor index. 40 // producerResultIndexMap is a map from producer loop -> tensor index. 41 // Compute the inverse to get map from tensor index -> producer loop. 42 // The inverse is a map from producer result tensor index -> producer loop. 43 AffineMap invProducerResultIndexMap = 44 inversePermutation(producerResultIndexMap); 45 assert(invProducerResultIndexMap && 46 "expected producer result indexig map to be invertible"); 47 48 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner()); 49 // argMap is a map from producer loop -> producer arg tensor index. 50 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); 51 52 // Compose argMap with invProducerResultIndexMap to get a map from 53 // producer result tensor index -> producer arg tensor index. 54 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 55 56 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 57 // consumer loop/ fused loop -> producer arg tensor index. 58 return t1.compose(fusedConsumerArgIndexMap); 59 } 60 61 /// Conditions for elementwise fusion of generic operations. 62 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, 63 OpOperand *consumerOpOperand) { 64 // Producer and consumer must have tensor semantics. 65 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 66 return false; 67 68 // Verify that 69 // - the producer has all "parallel" iterator type. 70 if (producer.getNumParallelLoops() != producer.getNumLoops()) 71 return false; 72 73 // Only allow fusing the producer of an input operand for now. 74 // TODO: allow fusing the producer of an output operand. 
75 if (!consumer.isInputTensor(consumerOpOperand)) 76 return false; 77 78 // Get the consumer index map. The number of results of the consumer index 79 // map must match the number of loops of the producer. 80 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 81 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 82 return false; 83 84 // Currently support only operations with single result. 85 if (producer.getNumOutputs() != 1) 86 return false; 87 88 // Finally the index_map for the result must be invertible. For now just 89 // verify it is a permutation. 90 AffineMap producerResultIndexMap = 91 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 92 if (!producerResultIndexMap.isPermutation()) 93 return false; 94 95 // Ensure that the fusion does not remove size information required to 96 // get the loop bounds. For non-reduction generics, this is trivially the 97 // case due to the output operand. For reductions, we need to check that after 98 // the fusion, each loop dimension has at least one input that defines it. 99 if ((consumer.getNumReductionLoops())) { 100 llvm::BitVector coveredDims(consumer.getNumLoops(), false); 101 102 auto addToCoveredDims = [&](AffineMap map) { 103 for (auto result : map.getResults()) 104 if (auto dimExpr = result.dyn_cast<AffineDimExpr>()) 105 coveredDims[dimExpr.getPosition()] = true; 106 }; 107 108 for (auto pair : 109 llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) { 110 Value operand = std::get<0>(pair); 111 if (operand == consumerOpOperand->get()) 112 continue; 113 AffineMap operandMap = std::get<1>(pair); 114 addToCoveredDims(operandMap); 115 } 116 117 for (OpOperand *operand : producer.getInputOperands()) { 118 AffineMap newIndexingMap = 119 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 120 operand, producerResultIndexMap, consumerIndexMap); 121 addToCoveredDims(newIndexingMap); 122 } 123 if (!coveredDims.all()) 124 return false; 125 } 126 127 return true; 128 } 129 130 /// Generate the region of the fused tensor operation. The region of the fused 131 /// op must be empty. 132 static void 133 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 134 AffineMap consumerToProducerLoopsMap, 135 OpOperand *consumerOpOperand, 136 unsigned nloops) { 137 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 138 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 139 // Build the region of the fused op. 140 Block &producerBlock = producer->getRegion(0).front(); 141 Block &consumerBlock = consumer->getRegion(0).front(); 142 Block *fusedBlock = new Block(); 143 fusedOp.region().push_back(fusedBlock); 144 BlockAndValueMapping mapper; 145 OpBuilder::InsertionGuard guard(rewriter); 146 rewriter.setInsertionPointToStart(fusedBlock); 147 148 // 2. Add an index operation for every fused loop dimension and use the 149 // `consumerToProducerLoopsMap` to map the producer indices. 150 if (producer.hasIndexSemantics()) { 151 // Add an index operation for every fused loop dimension. 
152 unsigned numFusedOpLoops = 153 std::max(producer.getNumLoops(), consumer.getNumLoops()); 154 SmallVector<Value> fusedIndices; 155 fusedIndices.reserve(numFusedOpLoops); 156 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 157 std::back_inserter(fusedIndices), [&](uint64_t dim) { 158 return rewriter.create<IndexOp>(producer.getLoc(), dim); 159 }); 160 for (IndexOp indexOp : 161 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 162 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 163 producer.getLoc(), 164 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 165 mapper.map(indexOp.getResult(), newIndex); 166 } 167 } 168 // TODO: allow fusing the producer of an output operand. 169 assert(consumer.isInputTensor(consumerOpOperand) && 170 "expected producer of input operand"); 171 // 3. Consumer input operands up to consumerIdx (exclusive). 172 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 173 consumerOpOperand->getOperandNumber())) // input assumption. 174 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 175 176 // Replacing consumerIdx requires getting the cloned, yielded, value from 177 // the (cloned) producer block. This happens in step 9. 178 179 // 4. Splice in producer's input operands. 180 for (BlockArgument bbArg : 181 producerBlock.getArguments().take_front(producer.getNumInputs())) 182 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 183 184 // 4.b. Producer output operand/map that is fused needs to be mapped to the 185 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 186 assert(producer->getNumResults() == 1 && "expected single result producer"); 187 if (producer.isInitTensor(producer.getOutputOperand(0))) { 188 BlockArgument bbArg = producerBlock.getArguments() 189 .drop_front(producer.getNumInputs()) 190 // TODO: bbArg index of 191 .front(); 192 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 193 } 194 // 5. Remaining consumer's input operands (drop past index `consumerIdx`). 195 for (BlockArgument bbArg : 196 consumerBlock.getArguments() 197 .take_front(consumer.getNumInputs()) 198 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 199 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 200 // 6. All of consumer's output operands. 201 for (BlockArgument bbArg : 202 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 203 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 204 // 7. All of producer's output operands except the one fused. 205 // TODO: allow fusion of multi-result producers. 206 assert(producer->getNumResults() == 1 && "expected single result producer"); 207 208 // 8. Clone all producer operations except for the yield and index operations 209 // to the fused operation. 210 for (auto &op : producerBlock.without_terminator()) { 211 if (!isa<IndexOp>(op)) 212 rewriter.clone(op, mapper); 213 } 214 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just 215 // forward the yield operand. 216 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 217 // TODO: allow fusion of multi-result producers. 218 assert(producer->getNumResults() == 1 && "expected single result producer"); 219 unsigned producerResultNumber = 0; 220 Value replacement = 221 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 222 // Sanity checks, if replacement is not already in the mapper then it must be 223 // produced outside. 
224 if (replacement == yieldOp.getOperand(producerResultNumber)) { 225 if (auto bb = replacement.dyn_cast<BlockArgument>()) 226 assert(bb.getOwner() != &producerBlock && 227 "yielded block argument must have been mapped"); 228 else 229 assert(!producer->isAncestor(replacement.getDefiningOp()) && 230 "yielded value must have been mapped"); 231 } 232 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 233 replacement); 234 // 10. Clone operations from the consumer to the fused op. 235 for (auto &op : consumerBlock.getOperations()) 236 rewriter.clone(op, mapper); 237 238 // Sanity checks. 239 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 240 "Ill-formed GenericOp region"); 241 } 242 243 static Optional<SmallVector<Value>> 244 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 245 const ControlElementwiseOpsFusionFn &controlFn, 246 PatternRewriter &rewriter) { 247 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 248 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 249 !controlFn(producer->getResult(0), *consumerOpOperand)) 250 return llvm::None; 251 252 // TODO: allow fusing the producer of an output operand. 253 assert(consumer.isInputTensor(consumerOpOperand) && 254 "expected producer of input operand"); 255 256 // Compute the fused operands list and indexing maps. 257 SmallVector<Value> fusedOperands; 258 SmallVector<AffineMap> fusedIndexMaps; 259 fusedOperands.reserve(producer->getNumOperands() + 260 consumer->getNumOperands()); 261 fusedIndexMaps.reserve(producer->getNumOperands() + 262 consumer->getNumOperands()); 263 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 264 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 265 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 266 SmallVector<OpOperand *>::iterator it = 267 llvm::find(consumerInputs, consumerOpOperand); 268 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 269 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 270 fusedOperands.push_back(opOperand->get()); 271 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 272 } 273 // 4. Splice in producer's input operands/maps. 274 assert(producer->getNumResults() == 1 && "expected single result producer"); 275 AffineMap producerResultIndexMap = 276 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 277 for (OpOperand *opOperand : producer.getInputOperands()) { 278 fusedOperands.push_back(opOperand->get()); 279 // Compute indexing maps for the producer args in the fused operation. 280 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 281 opOperand, producerResultIndexMap, 282 consumer.getTiedIndexingMap(consumerOpOperand)); 283 fusedIndexMaps.push_back(map); 284 } 285 // 4.b. Producer output operand/map that is fused needs to be passed if it is 286 // an "initTensor" (i.e. its value is actually read). 287 assert(producer->getNumResults() == 1 && "expected single result producer"); 288 if (producer.isInitTensor(producer.getOutputOperand(0))) { 289 fusedOperands.push_back(producer.getOutputOperand(0)->get()); 290 // Compute indexing maps for the producer args in the fused operation. 291 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 292 producer.getOutputOperand(0), producerResultIndexMap, 293 consumer.getTiedIndexingMap(consumerOpOperand)); 294 fusedIndexMaps.push_back(map); 295 } 296 // 5. 
Remaining consumer's input operands/maps (drop past index 297 // `consumerIdx`). 298 for (OpOperand *opOperand : 299 llvm::make_range(std::next(it), consumerInputs.end())) { 300 fusedOperands.push_back(opOperand->get()); 301 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 302 } 303 // 6. All of consumer's output operands (skip operands: added by the builder). 304 for (OpOperand *opOperand : consumer.getOutputOperands()) 305 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 306 // 7. All of producer's output operands/maps except the one fused. 307 // TODO: allow fusion of multi-result producers. 308 assert(producer->getNumResults() == 1 && "expected single result producer"); 309 310 // Generate the fused op. 311 SmallVector<Value> consumerOutputs = consumer.getOutputOperands(); 312 auto fusedOp = rewriter.create<GenericOp>( 313 consumer.getLoc(), consumer->getResultTypes(), 314 /*inputs=*/fusedOperands, 315 // TODO: handle outputs. 316 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 317 consumer.iterator_types(), 318 /*doc=*/nullptr, 319 /*library_call=*/nullptr); 320 321 // Construct an AffineMap from consumer loops to producer loops. 322 // consumer loop -> tensor index 323 AffineMap consumerResultIndexMap = 324 consumer.getTiedIndexingMap(consumerOpOperand); 325 // tensor index -> producer loop 326 AffineMap invProducerResultIndexMap = 327 inversePermutation(producerResultIndexMap); 328 assert(invProducerResultIndexMap && 329 "expected producer result indexig map to be invertible"); 330 // consumer loop -> producer loop 331 AffineMap consumerToProducerLoopsMap = 332 invProducerResultIndexMap.compose(consumerResultIndexMap); 333 334 generateFusedElementwiseOpRegion(rewriter, fusedOp, 335 consumerToProducerLoopsMap, 336 consumerOpOperand, consumer.getNumLoops()); 337 return SmallVector<Value>(fusedOp->getResults()); 338 } 339 340 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` 341 /// provided, given the shape of the source tensor that corresponds to the 342 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions 343 /// are "row-major" ordered logically. 344 /// 345 /// For example: 346 /// 347 /// %0 = op ... : tensor<?x?x4x5xf32> 348 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` 349 /// 350 /// and reshape: 351 /// %1 = linalg.tensor_collapse_shape %0 [[0], [0, 1, 2]] : 352 /// tensor<?x?x4x5xf32> into tensor<?x?xf32> 353 /// 354 /// would be rewritten into: 355 /// %0 = op ... : tensor<?x?x4x5xf32> 356 /// with output index_map 357 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` 358 template <typename TensorReshapeOp> 359 static AffineMap linearizeCollapsedDims(AffineMap sourceMap, 360 TensorReshapeOp reshapeOp) { 361 constexpr bool isExpanding = 362 std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value; 363 ArrayRef<int64_t> sourceShape = 364 (isExpanding ? reshapeOp.getResultType().getShape() 365 : reshapeOp.getSrcType().getShape()); 366 SmallVector<AffineExpr> resultExprs; 367 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults(); 368 MLIRContext *context = sourceMap.getContext(); 369 370 // Compute the result exprs based on the reassociation maps. 371 for (auto &indices : reshapeOp.getReassociationIndices()) { 372 // Assume that they are in-order and contiguous (already checked in 373 // verifier). 
374 assert(!indices.empty()); 375 SmallVector<int64_t> sizes; 376 SmallVector<AffineExpr> dimExprs; 377 for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()), 378 sourceExprs.slice(indices[0], indices.size()))) { 379 if (std::get<0>(en) == 1) 380 continue; 381 sizes.push_back(std::get<0>(en)); 382 dimExprs.push_back(std::get<1>(en)); 383 } 384 AffineExpr linearizedExpr = 385 makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); 386 resultExprs.push_back(linearizedExpr); 387 } 388 // The new affine map cannot drop unused dimension but some new symbols may 389 // have been added. Create a map with at least as many dimensions/symbols as 390 // the original affine map. 391 int64_t maxDim = -1; 392 int64_t maxSym = -1; 393 getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym); 394 unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims()); 395 unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols()); 396 return AffineMap::get(numDims, numSyms, resultExprs, context); 397 } 398 399 // TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a 400 // producer). Fusing when operand has higher rank will require use of mods and 401 // divs in the indexing maps of the fused op which would make it non-invertible. 402 static bool isTensorReshapeOpFoldableByLinearization( 403 TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { 404 if (!asProducer) 405 return false; 406 return useIndexMap.isPermutation(); 407 } 408 409 // TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a 410 // consumer). 411 static bool isTensorReshapeOpFoldableByLinearization( 412 TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) { 413 if (asProducer) 414 return false; 415 return useIndexMap.isPermutation(); 416 } 417 418 /// Check if the reshape operation is only expansion into/collapsing of 419 /// unit-dimension. 420 template <typename TensorReshapeOp> 421 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) { 422 constexpr bool isExpanding = 423 std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value; 424 ArrayRef<int64_t> expandedShape = 425 (isExpanding ? reshapeOp.getResultType().getShape() 426 : reshapeOp.getSrcType().getShape()); 427 for (auto &indices : reshapeOp.getReassociationIndices()) { 428 unsigned numUnitDims = 0; 429 for (int64_t position : indices) 430 if (expandedShape[position] == 1) 431 numUnitDims++; 432 if (numUnitDims != indices.size() - 1) 433 return false; 434 } 435 return true; 436 } 437 438 /// Conditions for folding a generic operation with a reshape op by expanding 439 /// the iteration space dimensionality for tensor operations. These are 440 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 441 /// following fusion pattern. 442 /// 443 /// Consider 444 /// 445 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>) 446 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 447 /// affine_map<(d0, d1, d2) -> (d1, d2)>, 448 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] 449 /// %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]] 450 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 451 /// 452 /// The reshape can be folded into the `genericOp` if its loop dimensionality 453 /// is increased to match the result (operand) of the tensor_expand_shape. 
454 /// The indexing_map of the fused tensor in the `genericOp` and the 455 /// reassociation map helps compute the indexing maps of the modified op. 456 /// For the above example, based on the reassociation map it 457 /// can be concluded that 458 /// 459 /// - The loop used to access the first dimension of the fused tensor is split 460 /// into two. 461 /// - The loop used to access the second dimension of the fused tensor is kept 462 /// as is. 463 /// - The loop used to access the third dimension of the fused tensor is split 464 /// into three. 465 /// 466 /// i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified 467 /// op, then 468 /// 469 /// d0 -> e0, e1 470 /// d1 -> e2, e3, e4 471 /// d2 -> e5 472 /// 473 /// substituting this, the generic op can be rewritten as 474 /// 475 /// %d = linalg.generic ins(%0, %1 : ) 476 /// indexing_maps = 477 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>, 478 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>, 479 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>] 480 /// 481 /// Since operands to the linalg generic are now 5D, reshapes can be introduced 482 /// to make it consistent 483 /// 484 /// %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]] 485 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 486 /// %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]] 487 /// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32> 488 /// 489 /// The added reshapes are again expanding patterns, so they will get fused 490 /// with its producers if possible. 491 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, 492 OpOperand *fusableOpOperand) { 493 // Is fusable only if: 494 // - All the indexing maps for operands and results are projected 495 // permutations. 496 // - The fused tensor is not a scalar. 497 // - All the loops are parallel loops. 498 return genericOp.hasTensorSemantics() && 499 llvm::all_of(genericOp.indexing_maps().getValue(), 500 [](Attribute attr) { 501 return attr.cast<AffineMapAttr>() 502 .getValue() 503 .isProjectedPermutation(); 504 }) && 505 genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && 506 llvm::all_of(genericOp.iterator_types(), [](Attribute attr) { 507 return attr.cast<StringAttr>().getValue() == 508 getParallelIteratorTypeName(); 509 }); 510 } 511 512 namespace { 513 /// Information needed to expand a generic operation to fold the reshape with 514 /// it. 515 class ExpansionInfo { 516 public: 517 // Computes the mapping from original dimensions of the op to the dimensions 518 // of the expanded op given the `indexingMap` of the fused operand/result of 519 // the generic op, the `reassocationMaps` of the reshape op and the shape of 520 // the expanded op. 521 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand, 522 ArrayRef<AffineMap> reassociationMaps, 523 ArrayRef<int64_t> expandedShape, 524 PatternRewriter &rewriter); 525 unsigned getOrigOpNumDims() const { return reassociation.size(); } 526 unsigned getExpandedOpNumDims() const { return expandedOpNumDims; } 527 ReassociationIndicesRef getExpandedDims(unsigned i) const { 528 return reassociation[i]; 529 } 530 ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const { 531 return expandedShapeMap[i]; 532 } 533 534 private: 535 /// Reassociation from the dimensions in the original operation to the 536 /// dimension of the expanded operation. 
537 SmallVector<ReassociationIndices> reassociation; 538 /// Mapping from extent of loops in the original operation, to the extent of 539 /// loops in the expanded operation. 540 SmallVector<SmallVector<int64_t>> expandedShapeMap; 541 unsigned expandedOpNumDims; 542 }; 543 } // namespace 544 545 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp, 546 OpOperand *fusableOpOperand, 547 ArrayRef<AffineMap> reassociationMaps, 548 ArrayRef<int64_t> expandedShape, 549 PatternRewriter &rewriter) { 550 if (reassociationMaps.empty()) 551 return failure(); 552 AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand); 553 554 Optional<SmallVector<int64_t, 4>> originalLoopRange = 555 linalgOp.getStaticLoopRanges(); 556 if (!originalLoopRange) 557 return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range"); 558 559 reassociation.clear(); 560 expandedShapeMap.clear(); 561 // Compute the number of dimension in the expanded op that correspond to each 562 // dimension of the original op. 563 SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1); 564 expandedShapeMap.resize(fusedIndexMap.getNumDims()); 565 for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { 566 unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition(); 567 AffineMap foldedDims = reassociationMaps[resultExpr.index()]; 568 numExpandedDims[pos] = foldedDims.getNumResults(); 569 ArrayRef<int64_t> shape = 570 expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]); 571 expandedShapeMap[pos].assign(shape.begin(), shape.end()); 572 } 573 // The remaining dimensions remain the same. 574 for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims())) 575 if (expandedShapeMap[i].empty()) 576 expandedShapeMap[i] = {(*originalLoopRange)[i]}; 577 578 // Compute reassociation map from the original op to the expanded op. 579 unsigned sum = 0; 580 reassociation.reserve(fusedIndexMap.getNumDims()); 581 for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) { 582 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value()); 583 reassociation.emplace_back(seq.begin(), seq.end()); 584 sum += numFoldedDim.value(); 585 } 586 expandedOpNumDims = sum; 587 return success(); 588 } 589 590 /// Epanding the body of a linalg operation requires adaptations of the accessed 591 /// loop indices. Specifically, access of indices in the original operation need 592 /// to be replaced with linearizations of indices in the expanded op. That 593 /// requires the shape of the expanded dimensions to be static (at least all but 594 /// the most significant). For now check that these are all statically sized. 595 /// Note that this could be extended to handle dynamic case, but the 596 /// implementation below uses `affine.apply` which seems to have issues when the 597 /// shapes are not static. 
598 LogicalResult isGenericOpExpandable(GenericOp genericOp, 599 const ExpansionInfo &expansionInfo, 600 PatternRewriter &rewriter) { 601 if (!genericOp.hasIndexSemantics()) 602 return success(); 603 for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) { 604 ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i); 605 if (expandedShape.size() == 1) 606 continue; 607 for (int64_t shape : expandedShape.drop_front()) { 608 if (ShapedType::isDynamic(shape)) { 609 return rewriter.notifyMatchFailure( 610 genericOp, "cannot expand due to index semantics and dynamic dims"); 611 } 612 } 613 } 614 return success(); 615 } 616 617 /// Return the indexing map to use in the expanded op for a given the 618 /// `indexingMap` of the original operation. 619 static AffineMap 620 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 621 const ExpansionInfo &expansionInfo) { 622 SmallVector<AffineExpr> newExprs; 623 for (AffineExpr expr : indexingMap.getResults()) { 624 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 625 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 626 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 627 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 628 })); 629 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 630 } 631 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 632 indexingMap.getNumSymbols(), newExprs, 633 builder.getContext()); 634 } 635 636 /// Return the type of the operand/result to use in the expanded op given the 637 /// type in the original op. 638 static RankedTensorType getExpandedType(RankedTensorType originalType, 639 AffineMap indexingMap, 640 const ExpansionInfo &expansionInfo) { 641 SmallVector<int64_t> expandedShape; 642 for (AffineExpr expr : indexingMap.getResults()) { 643 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 644 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 645 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 646 } 647 return RankedTensorType::get(expandedShape, originalType.getElementType()); 648 } 649 650 /// Returns the reassociation maps to use in the `linalg.tensor_expand_shape` 651 /// operation to convert the operands of the original operation to operands of 652 /// the expanded operation. The same method is used to compute the 653 /// `linalg.tensor_collapse_shape` used to collapse the result of the expanded 654 /// op to get the value that can replace all uses of the results of the original 655 /// op. 656 static SmallVector<ReassociationIndices> 657 getReassociationForExpansion(AffineMap indexingMap, 658 const ExpansionInfo &expansionInfo) { 659 SmallVector<ReassociationIndices> reassociation; 660 unsigned numReshapeDims = 0; 661 for (AffineExpr expr : indexingMap.getResults()) { 662 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 663 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 664 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 665 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 666 reassociation.emplace_back(std::move(indices)); 667 numReshapeDims += numExpandedDims; 668 } 669 return reassociation; 670 } 671 672 /// Update the body of an expanded linalg operation having index semantics. The 673 /// indices of the original operation need to be recovered by linearizing the 674 /// indices of the correspoding dimensions of the expanded operation. 
For now it 675 /// is assumed that the shapes of the expanded operation needed for 676 /// linearization are static. 677 static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, 678 Location loc, Region &fusedRegion, 679 const ExpansionInfo &expansionInfo) { 680 // Replace the original indices by the linearization of the expanded indices. 681 for (IndexOp indexOp : 682 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) { 683 ArrayRef<int64_t> expandedDims = 684 expansionInfo.getExpandedDims(indexOp.dim()); 685 assert(!expandedDims.empty() && "expected valid expansion info"); 686 687 // Skip index operations that are not affected by the expansion. 688 if (expandedDims.size() == 1 && 689 expandedDims.front() == (int64_t)indexOp.dim()) 690 continue; 691 692 // Linearize the expanded indices of the original index dimension. 693 OpBuilder::InsertionGuard guard(rewriter); 694 rewriter.setInsertionPointAfter(indexOp); 695 ArrayRef<int64_t> expandedDimsShape = 696 expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front(); 697 SmallVector<Value> expandedIndices; 698 expandedIndices.reserve(expandedDims.size() - 1); 699 llvm::transform( 700 expandedDims.drop_front(), std::back_inserter(expandedIndices), 701 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); }); 702 Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front()); 703 for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) { 704 assert(!ShapedType::isDynamic(std::get<0>(it))); 705 AffineExpr idx, acc; 706 bindDims(rewriter.getContext(), idx, acc); 707 newIndex = rewriter.create<AffineApplyOp>( 708 indexOp.getLoc(), idx + acc * std::get<0>(it), 709 ValueRange{std::get<1>(it), newIndex}); 710 } 711 rewriter.replaceOp(indexOp, newIndex); 712 } 713 } 714 715 /// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op 716 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes 717 /// that those conditions have been satisfied. 718 static Optional<SmallVector<Value>> 719 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, 720 OpOperand *fusableOpOperand, 721 PatternRewriter &rewriter) { 722 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && 723 "preconditions for fuse operation failed"); 724 // Check if reshape is expanding or collapsing. 725 auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp); 726 auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp); 727 bool isExpanding = (expandingReshapeOp != nullptr); 728 RankedTensorType expandedType = isExpanding 729 ? expandingReshapeOp.getResultType() 730 : collapsingReshapeOp.getSrcType(); 731 732 ExpansionInfo expansionInfo; 733 if (failed(expansionInfo.compute( 734 genericOp, fusableOpOperand, 735 isExpanding ? 
expandingReshapeOp.getReassociationMaps() 736 : collapsingReshapeOp.getReassociationMaps(), 737 expandedType.getShape(), rewriter))) 738 return llvm::None; 739 740 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) 741 return llvm::None; 742 743 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>( 744 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) { 745 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); 746 })); 747 748 SmallVector<Value> expandedOpOperands; 749 expandedOpOperands.reserve(genericOp.getNumInputs()); 750 for (OpOperand *opOperand : genericOp.getInputOperands()) { 751 if (opOperand == fusableOpOperand) { 752 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() 753 : collapsingReshapeOp.src()); 754 continue; 755 } 756 if (genericOp.isInputTensor(opOperand)) { 757 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 758 RankedTensorType expandedOperandType = 759 getExpandedType(opOperand->get().getType().cast<RankedTensorType>(), 760 indexingMap, expansionInfo); 761 if (expandedOperandType != opOperand->get().getType()) { 762 // Reshape the operand to get the right type. 763 SmallVector<ReassociationIndices> reassociation = 764 getReassociationForExpansion(indexingMap, expansionInfo); 765 expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>( 766 genericOp.getLoc(), expandedOperandType, opOperand->get(), 767 reassociation)); 768 continue; 769 } 770 } 771 expandedOpOperands.push_back(opOperand->get()); 772 } 773 774 Location loc = genericOp.getLoc(); 775 SmallVector<Value> outputs; 776 for (OpOperand *opOperand : genericOp.getOutputOperands()) { 777 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 778 RankedTensorType expandedOutputType = 779 getExpandedType(opOperand->get().getType().cast<RankedTensorType>(), 780 indexingMap, expansionInfo); 781 if (expandedOutputType != opOperand->get().getType()) { 782 SmallVector<ReassociationIndices> reassociation = 783 getReassociationForExpansion(indexingMap, expansionInfo); 784 outputs.push_back(rewriter.create<TensorExpandShapeOp>( 785 genericOp.getLoc(), expandedOutputType, opOperand->get(), 786 reassociation)); 787 } 788 } 789 790 // The iterator types of the expanded op are all parallel. 791 SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(), 792 getParallelIteratorTypeName()); 793 794 TypeRange resultTypes = ValueRange(outputs).getTypes(); 795 auto fusedOp = 796 rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes, 797 /*inputs=*/expandedOpOperands, outputs, 798 expandedOpIndexingMaps, iteratorTypes); 799 Region &fusedRegion = fusedOp->getRegion(0); 800 Region &originalRegion = genericOp->getRegion(0); 801 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin()); 802 803 // Update the index accesses after the expansion. 804 updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo); 805 806 // Reshape the result values to their original shape if this is a collapsing 807 // reshape folded into its consumer. 
808 SmallVector<Value> resultVals; 809 for (OpResult opResult : genericOp->getOpResults()) { 810 int64_t resultNumber = opResult.getResultNumber(); 811 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) { 812 SmallVector<ReassociationIndices> reassociation = 813 getReassociationForExpansion( 814 genericOp.getTiedIndexingMap( 815 genericOp.getOutputOperand(resultNumber)), 816 expansionInfo); 817 resultVals.push_back(rewriter.create<TensorCollapseShapeOp>( 818 genericOp.getLoc(), opResult.getType(), 819 fusedOp->getResult(resultNumber), reassociation)); 820 } else { 821 resultVals.push_back(fusedOp->getResult(resultNumber)); 822 } 823 } 824 // Assuming a single result. 825 return resultVals; 826 } 827 828 namespace { 829 830 /// Pattern to fold tensor_expand_shape op with its consumer by using the source 831 /// of the reshape op as the operand in the consumer (instead of the result of 832 /// the tensor_collapse_shape). The corresponding index map in the consumer 833 /// needs to be modified to linearize the folded dimension. 834 /// 835 /// For example, 836 /// 837 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 838 /// %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]] 839 /// tensor<?x?x?xf32> into tensor<?x?x4x?xf32> 840 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... } 841 /// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ... 842 /// -> tensor<?x?x4x?xf32> 843 /// 844 /// can be folded into 845 /// 846 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> 847 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 848 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... } 849 /// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ... 850 /// -> tensor<?x?x4x?xf32> 851 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp> 852 struct FoldProducerReshapeOpByLinearization 853 : public OpRewritePattern<GenericOp> { 854 using OpRewritePattern<GenericOp>::OpRewritePattern; 855 856 LogicalResult matchAndRewrite(GenericOp genericOp, 857 PatternRewriter &rewriter) const override { 858 if (!genericOp.hasTensorSemantics()) 859 return failure(); 860 SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands(); 861 for (auto en : llvm::enumerate(inputOperands)) { 862 auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>(); 863 if (!reshapeOp) 864 continue; 865 866 if (!isTensorReshapeOpFoldableByLinearization( 867 reshapeOp, genericOp.getTiedIndexingMap(en.value()), 868 /*asProducer =*/true) || 869 (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) 870 continue; 871 872 // Compute the fused operands list, 873 SmallVector<Value> fusedOperands = genericOp.getInputOperands(); 874 fusedOperands[en.index()] = reshapeOp.src(); 875 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 876 llvm::append_range(fusedOperands, outputOperands); 877 878 // Compute indexing_maps for the fused operation. The indexing_maps for 879 // the operands of the consumers that arent fused are the same. 880 SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps(); 881 882 // Accepted consumer maps are either identity or permutation. 883 auto invMap = inversePermutation(fusedIndexMaps[en.index()]); 884 885 // Compute the indexing map to use for the result of the producer. 886 AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp); 887 // The modified map cannot have symbols. 
888 if (modifiedMap.getNumSymbols()) 889 return failure(); 890 for (AffineExpr expr : modifiedMap.getResults()) { 891 if (!expr.isPureAffine()) 892 return failure(); 893 } 894 fusedIndexMaps[en.index()] = modifiedMap; 895 896 // Further check that the resulting index maps can be fused and 897 // inverted. Without this the resultant op is not legal. 898 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 899 return rewriter.notifyMatchFailure( 900 genericOp, "fused op loop bound computation failed"); 901 } 902 903 rewriter.startRootUpdate(genericOp); 904 genericOp->setOperands(fusedOperands); 905 genericOp.indexing_mapsAttr( 906 rewriter.getAffineMapArrayAttr(fusedIndexMaps)); 907 rewriter.finalizeRootUpdate(genericOp); 908 return success(); 909 } 910 return failure(); 911 } 912 }; 913 914 static SmallVector<ReassociationIndices> 915 getReassociationIndices(ArrayRef<AffineMap> maps) { 916 SmallVector<ReassociationIndices> reassociation; 917 for (AffineMap map : maps) { 918 ReassociationIndices indices; 919 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 920 unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition(); 921 indices.push_back(pos); 922 } 923 reassociation.push_back(indices); 924 } 925 return reassociation; 926 } 927 928 /// Pattern to move rank reducing reshape after an elementwise linalg generic 929 /// op. This is useful to expose more fusion opportunities between named ops and 930 /// generic ops. This can only be done if there is no broadcast or permuation 931 /// within the dimensions we need to merge. 932 /// 933 /// For example, 934 /// 935 /// %0 = linalg.tensor_expand_shape %A [[0, 1], [2]] 936 /// : tensor<12544x16xf32> into tensor<112x112x16xf32> 937 /// %2 = linalg.generic {indexing_maps = [ 938 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 939 /// affine_map<(d0, d1, d2) -> (d2)>, 940 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = 941 /// ["parallel", "parallel", "parallel"]} { 942 /// } -> tensor<112x112x16xf32> 943 /// 944 /// into 945 /// 946 /// %2 = linalg.generic {indexing_maps = [ 947 /// affine_map<(d0, d1) -> (d0, d1)>, 948 /// affine_map<(d0, d1) -> (d1)>, 949 /// affine_map<(d0, d1) -> (d0, d1)>], 950 /// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 951 /// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) { 952 /// } -> tensor<12544x16xf32> 953 /// %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]] 954 /// : tensor<12544x16xf32> into tensor<112x112x16xf32> 955 struct PushExpandingReshape : public OpRewritePattern<GenericOp> { 956 using OpRewritePattern<GenericOp>::OpRewritePattern; 957 958 LogicalResult matchAndRewrite(GenericOp genericOp, 959 PatternRewriter &rewriter) const override { 960 // Only apply to elementwise linalg on tensor. 961 if (!genericOp.hasTensorSemantics() || 962 genericOp.getNumParallelLoops() != genericOp.getNumLoops()) 963 return failure(); 964 // Only support identity output maps. It could be extended to permuations if 965 // needed. 966 if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) { 967 return !genericOp.getTiedIndexingMap(opOperand).isIdentity(); 968 })) 969 return failure(); 970 int64_t destRank = genericOp.getNumParallelLoops(); 971 SmallVector<Value> newOperands = genericOp.getInputOperands(); 972 TensorExpandShapeOp reshapeFound; 973 // 1. Look for tensor_expand_shape operands and figure out save the 974 // dimensions merged. 
975 SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands(); 976 for (auto en : llvm::enumerate(inputOperands)) { 977 auto reshapeOp = 978 en.value()->get().template getDefiningOp<TensorExpandShapeOp>(); 979 if (!reshapeOp) 980 continue; 981 // TODO: We could support non-identity map as long as the merged 982 // dimensions are still contiguous. 983 if (!genericOp.getTiedIndexingMap(en.value()).isIdentity()) 984 continue; 985 if (reshapeFound) { 986 // Only support a second reshape op if it has the same reassociate maps. 987 if (reshapeFound.getReassociationMaps() == 988 reshapeOp.getReassociationMaps()) 989 newOperands[en.index()] = reshapeOp.src(); 990 continue; 991 } 992 reshapeFound = reshapeOp; 993 newOperands[en.index()] = reshapeOp.src(); 994 } 995 if (!reshapeFound) 996 return failure(); 997 998 // Calculate the reassociation indices and rassociated reverse map. 999 SmallVector<ReassociationIndices> reassociation = 1000 getReassociationIndices(reshapeFound.getReassociationMaps()); 1001 SmallVector<unsigned> remap(destRank); 1002 for (auto &indices : llvm::enumerate(reassociation)) { 1003 for (int64_t index : indices.value()) { 1004 remap[index] = indices.index(); 1005 } 1006 } 1007 // 2. Verify that we can merge the dimensions in the linalg and that we 1008 // don't need to create new reshapes operands. Inserting new reshape 1009 // operands would defeat the purpose of the transformation. 1010 for (auto en : llvm::enumerate(inputOperands)) { 1011 if (en.value()->get() == newOperands[en.index()]) { 1012 AffineMap map = genericOp.getTiedIndexingMap(en.value()); 1013 for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { 1014 if (reassociation[remap[map.getDimPosition(i)]].size() > 1) 1015 return failure(); 1016 } 1017 } 1018 } 1019 1020 // 3. Calculate the affine map remapping and the reassociation to apply to 1021 // output tensors. 1022 SmallVector<AffineMap> newMaps; 1023 unsigned newRank = reassociation.size(); 1024 for (auto map : genericOp.getIndexingMaps()) { 1025 SmallVector<AffineExpr> newExprs; 1026 for (auto expr : map.getResults()) { 1027 unsigned position = expr.template cast<AffineDimExpr>().getPosition(); 1028 // Skip dimension merged except for the last of the group. 1029 if (reassociation[remap[position]].back() == position) { 1030 newExprs.push_back( 1031 getAffineDimExpr(remap[position], genericOp.getContext())); 1032 } 1033 } 1034 newMaps.push_back( 1035 AffineMap::get(newRank, 0, newExprs, genericOp.getContext())); 1036 } 1037 1038 // 4. Reshape the output tensors. 1039 SmallVector<Value> newOutputs; 1040 SmallVector<Type> newOutputTypes; 1041 for (auto output : genericOp.outputs()) { 1042 auto newOutputType = RankedTensorType::get( 1043 reshapeFound.getSrcType().getShape(), 1044 output.getType().template cast<RankedTensorType>().getElementType()); 1045 Value newOutput = rewriter.create<TensorCollapseShapeOp>( 1046 genericOp->getLoc(), newOutputType, output, reassociation); 1047 newOutputTypes.push_back(newOutputType); 1048 newOutputs.push_back(newOutput); 1049 } 1050 // 5. Create a new generic op with lowerer rank. 1051 SmallVector<StringRef> iteratorTypes(newRank, 1052 getParallelIteratorTypeName()); 1053 auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes, 1054 newOperands, newOutputs, newMaps, 1055 iteratorTypes); 1056 rewriter.inlineRegionBefore(genericOp.region(), newOp.region(), 1057 newOp.region().begin()); 1058 // 6. Reshape the so that the type matches the uses. 
1059 SmallVector<Value> newResults; 1060 for (auto result : llvm::enumerate(newOp->getResults())) { 1061 newResults.push_back(rewriter.create<TensorExpandShapeOp>( 1062 genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()], 1063 result.value(), reassociation)); 1064 } 1065 rewriter.replaceOp(genericOp, newResults); 1066 return success(); 1067 } 1068 }; 1069 1070 /// Pattern to fuse a tensor_collapse_shape op with its consumer generic op, 1071 /// when the reshape op is collapsing dimensions. The dimensionality of the loop 1072 /// in the consumer is expanded. 1073 class FoldWithProducerReshapeOpByExpansion 1074 : public OpRewritePattern<GenericOp> { 1075 public: 1076 FoldWithProducerReshapeOpByExpansion( 1077 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1078 PatternBenefit benefit = 1) 1079 : OpRewritePattern<GenericOp>(context, benefit), 1080 controlFoldingReshapes(foldReshapes) {} 1081 1082 LogicalResult matchAndRewrite(GenericOp genericOp, 1083 PatternRewriter &rewriter) const override { 1084 for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { 1085 TensorCollapseShapeOp reshapeOp = 1086 opOperand->get().getDefiningOp<TensorCollapseShapeOp>(); 1087 if (!reshapeOp) 1088 continue; 1089 // Fold only if 1090 // - The tensor reshape op is folding. 1091 // - All constraints of fusing with reshape by expansion are met. 1092 if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) || 1093 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand))) 1094 continue; 1095 1096 Optional<SmallVector<Value>> replacementValues = 1097 fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); 1098 if (!replacementValues) 1099 return failure(); 1100 rewriter.replaceOp(genericOp, replacementValues.getValue()); 1101 return success(); 1102 } 1103 return failure(); 1104 } 1105 1106 private: 1107 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1108 }; 1109 1110 /// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its 1111 /// producer. The corresponding index map in the consumer needs to be modified 1112 /// to linearize the folded dimension. 1113 template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp> 1114 struct FoldConsumerReshapeOpByLinearization 1115 : public OpRewritePattern<TensorReshapeOp> { 1116 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 1117 1118 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 1119 PatternRewriter &rewriter) const override { 1120 GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>(); 1121 if (!producer || !producer.hasTensorSemantics() || 1122 producer.getNumOutputs() != 1 || 1123 !isTensorReshapeOpFoldableByLinearization( 1124 reshapeOp, 1125 producer.getTiedIndexingMap(producer.getOutputOperand(0)), 1126 /*asProducer =*/false) || 1127 (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) 1128 return failure(); 1129 // The indexing_maps for the operands of the fused operation are same as 1130 // those for the operands of the producer. 1131 SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps(); 1132 1133 auto invMap = inversePermutation( 1134 producer.getTiedIndexingMap(producer.getOutputOperand(0))); 1135 1136 // Compute the indexing map to use for the operand of the producer. 
1137 AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp); 1138 for (AffineExpr expr : modifiedMap.getResults()) { 1139 if (!expr.isPureAffine()) { 1140 return rewriter.notifyMatchFailure( 1141 producer, "fused op indexing map is not affine"); 1142 } 1143 } 1144 fusedIndexMaps.back() = modifiedMap; 1145 1146 // Further check that the resulting index maps can be fused and 1147 // inverted. Without this the resultant op is not legal. 1148 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 1149 return rewriter.notifyMatchFailure( 1150 producer, "fused op loop bound computation failed"); 1151 } 1152 1153 Location loc = producer.getLoc(); 1154 SmallVector<Value> inputOperands = producer.getInputOperands(); 1155 Value output = rewriter.create<TensorReshapeOp>( 1156 loc, producer.getOutputOperand(0)->get(), 1157 reshapeOp.getReassociationExprs()); 1158 auto fusedOp = rewriter.create<GenericOp>( 1159 loc, reshapeOp.getResultType(), 1160 /*inputs=*/inputOperands, 1161 // TODO: handle outputs. 1162 /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 1163 producer.iterator_types(), 1164 /*doc=*/nullptr, 1165 /*library_call=*/nullptr); 1166 auto &fusedRegion = fusedOp->getRegion(0); 1167 rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion, 1168 fusedRegion.begin()); 1169 rewriter.replaceOp(reshapeOp, fusedOp->getResults()); 1170 return success(); 1171 } 1172 }; 1173 1174 /// Pattern to fold a tensor_expand_shape op with its producer generic op 1175 /// by expanding the dimensionality of the loop in the producer op. 1176 struct FoldReshapeWithGenericOpByExpansion 1177 : public OpRewritePattern<TensorExpandShapeOp> { 1178 1179 FoldReshapeWithGenericOpByExpansion( 1180 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1181 PatternBenefit benefit = 1) 1182 : OpRewritePattern<TensorExpandShapeOp>(context, benefit), 1183 controlFoldingReshapes(foldReshapes) {} 1184 1185 LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp, 1186 PatternRewriter &rewriter) const override { 1187 // Fold only if all constraints of fusing with reshape by expansion are met. 1188 GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>(); 1189 if (!producer || producer.getNumOutputs() != 1 || 1190 !isFusableWithReshapeByDimExpansion(producer, 1191 producer.getOutputOperand(0)) || 1192 !controlFoldingReshapes(producer->getResult(0), 1193 reshapeOp->getOpOperand(0))) 1194 return failure(); 1195 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 1196 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 1197 if (!replacementValues) 1198 return failure(); 1199 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 1200 return success(); 1201 } 1202 1203 private: 1204 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1205 }; 1206 1207 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1208 /// handle cases where the constant is not single-valued. 
1209 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> { 1210 public: 1211 FoldScalarOrSplatConstant(MLIRContext *context, 1212 ControlElementwiseOpsFusionFn &fun, 1213 PatternBenefit benefit = 1) 1214 : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {} 1215 1216 LogicalResult matchAndRewrite(GenericOp genericOp, 1217 PatternRewriter &rewriter) const override { 1218 if (!genericOp.hasTensorSemantics()) 1219 return failure(); 1220 for (OpOperand *opOperand : genericOp.getInputOperands()) { 1221 Operation *def = opOperand->get().getDefiningOp(); 1222 Attribute constantAttr; 1223 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool { 1224 { 1225 DenseElementsAttr splatAttr; 1226 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) && 1227 splatAttr.isSplat() && 1228 splatAttr.getType().getElementType().isIntOrFloat()) { 1229 constantAttr = splatAttr.getSplatValue<Attribute>(); 1230 return true; 1231 } 1232 } 1233 { 1234 IntegerAttr intAttr; 1235 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) { 1236 constantAttr = intAttr; 1237 return true; 1238 } 1239 } 1240 { 1241 FloatAttr floatAttr; 1242 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) { 1243 constantAttr = floatAttr; 1244 return true; 1245 } 1246 } 1247 return false; 1248 }; 1249 1250 auto resultValue = opOperand->get().dyn_cast<OpResult>(); 1251 if (!def || !resultValue || !isScalarOrSplatConstantOp(def) || 1252 !controlFn(resultValue, *opOperand)) 1253 continue; 1254 1255 // The operands and the indexing_maps of the fused operation the same as 1256 // the operands and indexing_maps of the generic operations with the 1257 // values at the constant index dropped. 1258 SmallVector<AffineMap> fusedIndexMaps; 1259 SmallVector<Value> fusedOperands; 1260 SmallVector<Location> fusedLocs{genericOp.getLoc()}; 1261 fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); 1262 fusedOperands.reserve(genericOp.getNumInputs()); 1263 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); 1264 for (OpOperand *inputOperand : genericOp.getInputOperands()) { 1265 if (inputOperand == opOperand) 1266 continue; 1267 Value inputValue = inputOperand->get(); 1268 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand)); 1269 fusedOperands.push_back(inputValue); 1270 fusedLocs.push_back(inputValue.getLoc()); 1271 } 1272 for (OpOperand *outputOperand : genericOp.getOutputOperands()) 1273 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand)); 1274 1275 // Check if the operation shapes to loops map is computable. 1276 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { 1277 return rewriter.notifyMatchFailure( 1278 genericOp, "fused op loop bound computation failed"); 1279 } 1280 1281 // Create a constant scalar value from the splat constant. 1282 Value scalarConstant = rewriter.create<arith::ConstantOp>( 1283 def->getLoc(), constantAttr, constantAttr.getType()); 1284 1285 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 1286 auto fusedOp = rewriter.create<GenericOp>( 1287 rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), 1288 /*inputs=*/fusedOperands, 1289 /*outputs=*/outputOperands, 1290 rewriter.getAffineMapArrayAttr(fusedIndexMaps), 1291 genericOp.iterator_types(), 1292 /*doc=*/nullptr, 1293 /*library_call=*/nullptr); 1294 1295 // Map the block argument corresponding to the replaced argument with the 1296 // scalar constant. 
1297 Region ®ion = genericOp->getRegion(0); 1298 Block &entryBlock = *region.begin(); 1299 BlockAndValueMapping mapping; 1300 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()), 1301 scalarConstant); 1302 Region &fusedRegion = fusedOp->getRegion(0); 1303 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(), 1304 mapping); 1305 rewriter.replaceOp(genericOp, fusedOp->getResults()); 1306 return success(); 1307 } 1308 return failure(); 1309 } 1310 1311 private: 1312 ControlElementwiseOpsFusionFn controlFn; 1313 }; 1314 1315 /// Base class for constant folding linalg.generic ops with N inputs, 1 output, 1316 /// and permutation indexing maps. 1317 /// 1318 /// `ConcreteType` should provide methods with signatures 1319 /// 1320 /// ```c++ 1321 /// bool matchIndexingMaps(GenericOp genericOp) const; 1322 /// RegionComputationFn getRegionComputeFn(GenericOp) const; 1323 /// ``` 1324 /// 1325 /// The latter inspects the region and returns the computation inside as a 1326 /// functor. The functor will be invoked with constant elements for all inputs 1327 /// and should return the corresponding computea constant element for output. 1328 template <typename ConcreteType> 1329 class FoldConstantBase : public OpRewritePattern<GenericOp> { 1330 public: 1331 struct APIntOrFloat { 1332 Optional<APInt> apInt; 1333 Optional<APFloat> apFloat; 1334 }; 1335 struct APIntOrFloatArray { 1336 SmallVector<APInt> apInts; 1337 SmallVector<APFloat> apFloats; 1338 }; 1339 using RegionComputationFn = 1340 std::function<APIntOrFloat(const APIntOrFloatArray &)>; 1341 1342 FoldConstantBase(MLIRContext *context, 1343 const ControlElementwiseOpsFusionFn &controlFn, 1344 PatternBenefit benefit = 1) 1345 : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {} 1346 1347 LogicalResult matchAndRewrite(GenericOp genericOp, 1348 PatternRewriter &rewriter) const override { 1349 if (genericOp.hasBufferSemantics()) 1350 return failure(); 1351 1352 // Only support ops generating one output for now. 1353 if (genericOp.getNumOutputs() != 1) 1354 return failure(); 1355 1356 auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>(); 1357 // Require the output types to be static give we are generating constants. 1358 if (!outputType || !outputType.hasStaticShape()) 1359 return failure(); 1360 1361 if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { 1362 return operand->get().getType().isa<ShapedType>(); 1363 })) 1364 return failure(); 1365 1366 // Make sure all element types are the same. 1367 auto getOperandElementType = [](OpOperand *operand) { 1368 return operand->get().getType().cast<ShapedType>().getElementType(); 1369 }; 1370 if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(), 1371 getOperandElementType))) 1372 return failure(); 1373 1374 // We can only handle the case where we have int/float elements. 1375 auto elementType = outputType.getElementType(); 1376 if (!elementType.isIntOrFloat()) 1377 return failure(); 1378 1379 // Require all indexing maps to be permutations for now. This is common and 1380 // it simplifies input/output access greatly: we can do the data shuffling 1381 // entirely in the compiler, without needing to turn all indices into 1382 // Values, and then do affine apply on them, and then match back the 1383 // constant again. 
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases, but their lifetime is tied to the
    // MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };
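    // For example, the permutation map (d0, d1, d2) -> (d2, d0, d1) yields the
    // dim positions {2, 0, 1} from the lambda above.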

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate space for compute function inputs. Initial values do not matter
    // here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };
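
    // Worked example (illustrative): with loopBounds = {2, 3}, an output map
    // (d0, d1) -> (d0, d1) (2x3 output) and a single input with map
    // (d0, d1) -> (d1, d0) (3x2 input, i.e. a transpose), linearIndex = 4
    // delinearizes to loop indices {1, 1}; the remapping then yields
    // dstLinearIndex = 1 * 3 + 1 = 4 and srcLinearIndices[0] = 1 * 2 + 1 = 3,
    // i.e. output[1][1] is read from input[1][1].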

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
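// Illustrative sketch (schematic IR, not verified syntax): a generic whose
// region simply yields its input block argument, whose maps form a
// non-identity permutation, and whose input is a constant, e.g.
//   %cst = constant dense<...> : tensor<2x3xf32>
//   %0 = linalg.generic {indexing_maps = [(d0, d1) -> (d1, d0),
//                                         (d0, d1) -> (d0, d1)], ...}
//          ins(%cst : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {...}
// is folded into a single constant holding the transposed elements.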
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument that corresponds to the
    // input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), llvm::None};
      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
    };
  }

  ControlElementwiseOpsFusionFn controlFn;
};

} // namespace

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}

namespace {
/// Patterns to fuse a generic op with the producers of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    ControlElementwiseOpsFusionFn allowFoldingFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          return true;
        };
    populateElementwiseOpsFusionPatterns(
        patterns,
        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));

    // Use TopDownTraversal for compile time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
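/// For example (schematic IR): if the payload of
///   %0 = linalg.generic ... outs(%arg1 : tensor<?x?xf32>) {...}
/// never reads the value carried by %arg1, the `outs` operand is replaced by a
/// `linalg.init_tensor` of the same shape, removing the use-def dependence on
/// %arg1.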
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (auto dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

} // namespace

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<false, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
               FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
      patterns);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}