1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===/// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion on tensors operations pass. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "PassDetail.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/IR/AffineExpr.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 28 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of 29 /// the `producer` to use in the fused operation given the indexing map of the 30 /// result of the producer in the consumer. 31 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 32 OpOperand *producerOpOperand, AffineMap producerResultIndexMap, 33 AffineMap fusedConsumerArgIndexMap) { 34 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 35 // from consumer loop -> consumer arg tensor index/producer result tensor 36 // index. The fused loop is same as the consumer loop. For each producer arg 37 // the indexing map to be computed is a map from consumer loop -> producer 38 // arg tensor index. 39 // producerResultIndexMap is a map from producer loop -> tensor index. 
40 // Compute the inverse to get map from tensor index -> producer loop. 41 // The inverse is a map from producer result tensor index -> producer loop. 42 AffineMap invProducerResultIndexMap = 43 inversePermutation(producerResultIndexMap); 44 assert(invProducerResultIndexMap && 45 "expected producer result indexig map to be invertible"); 46 47 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner()); 48 // argMap is a map from producer loop -> producer arg tensor index. 49 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); 50 51 // Compose argMap with invProducerResultIndexMap to get a map from 52 // producer result tensor index -> producer arg tensor index. 53 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 54 55 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 56 // consumer loop/ fused loop -> producer arg tensor index. 57 return t1.compose(fusedConsumerArgIndexMap); 58 } 59 60 /// Conditions for elementwise fusion of generic operations. 61 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, 62 OpOperand *consumerOpOperand) { 63 // Producer and consumer must have tensor semantics. 64 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 65 return false; 66 67 // Verify that 68 // - the producer has all "parallel" iterator type. 69 if (producer.getNumParallelLoops() != producer.getNumLoops()) 70 return false; 71 72 // Only allow fusing the producer of an input operand for now. 73 // TODO: allow fusing the producer of an output operand. 74 if (!consumer.isInputTensor(consumerOpOperand)) 75 return false; 76 77 // Get the consumer index map. The number of results of the consumer index 78 // map must match the number of loops of the producer. 79 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 80 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 81 return false; 82 83 // Currently support only operations with single result. 
84 if (producer.getNumOutputs() != 1) 85 return false; 86 87 // Finally the index_map for the result must be invertible. For now just 88 // verify it is a permutation. 89 AffineMap producerResultIndexMap = 90 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 91 if (!producerResultIndexMap.isPermutation()) 92 return false; 93 94 // Ensure that the fusion does not remove size information required to 95 // get the loop bounds. For non-reduction generics, this is trivially the 96 // case due to the output operand. For reductions, we need to check that after 97 // the fusion, each loop dimension has at least one input that defines it. 98 if ((consumer.getNumReductionLoops())) { 99 llvm::BitVector coveredDims(consumer.getNumLoops(), false); 100 101 auto addToCoveredDims = [&](AffineMap map) { 102 for (auto result : map.getResults()) 103 if (auto dimExpr = result.dyn_cast<AffineDimExpr>()) 104 coveredDims[dimExpr.getPosition()] = true; 105 }; 106 107 for (auto pair : 108 llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) { 109 Value operand = std::get<0>(pair); 110 if (operand == consumerOpOperand->get()) 111 continue; 112 AffineMap operandMap = std::get<1>(pair); 113 addToCoveredDims(operandMap); 114 } 115 116 for (OpOperand *operand : producer.getInputOperands()) { 117 AffineMap newIndexingMap = 118 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 119 operand, producerResultIndexMap, consumerIndexMap); 120 addToCoveredDims(newIndexingMap); 121 } 122 if (!coveredDims.all()) 123 return false; 124 } 125 126 return true; 127 } 128 129 /// Generate the region of the fused tensor operation. The region of the fused 130 /// op must be empty. 
131 static void 132 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 133 AffineMap consumerToProducerLoopsMap, 134 OpOperand *consumerOpOperand, 135 unsigned nloops) { 136 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 137 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 138 // Build the region of the fused op. 139 Block &producerBlock = producer->getRegion(0).front(); 140 Block &consumerBlock = consumer->getRegion(0).front(); 141 Block *fusedBlock = new Block(); 142 fusedOp.region().push_back(fusedBlock); 143 BlockAndValueMapping mapper; 144 OpBuilder::InsertionGuard guard(rewriter); 145 rewriter.setInsertionPointToStart(fusedBlock); 146 147 // 2. Add an index operation for every fused loop dimension and use the 148 // `consumerToProducerLoopsMap` to map the producer indices. 149 if (producer.hasIndexSemantics()) { 150 // Add an index operation for every fused loop dimension. 151 unsigned numFusedOpLoops = 152 std::max(producer.getNumLoops(), consumer.getNumLoops()); 153 SmallVector<Value> fusedIndices; 154 fusedIndices.reserve(numFusedOpLoops); 155 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 156 std::back_inserter(fusedIndices), [&](uint64_t dim) { 157 return rewriter.create<IndexOp>(producer.getLoc(), dim); 158 }); 159 for (IndexOp indexOp : 160 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 161 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 162 producer.getLoc(), 163 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 164 mapper.map(indexOp.getResult(), newIndex); 165 } 166 } 167 // TODO: allow fusing the producer of an output operand. 168 assert(consumer.isInputTensor(consumerOpOperand) && 169 "expected producer of input operand"); 170 // 3. Consumer input operands up to consumerIdx (exclusive). 171 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 172 consumerOpOperand->getOperandNumber())) // input assumption. 
173 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 174 175 // Replacing consumerIdx requires getting the cloned, yielded, value from 176 // the (cloned) producer block. This happens in step 9. 177 178 // 4. Splice in producer's input operands. 179 for (BlockArgument bbArg : 180 producerBlock.getArguments().take_front(producer.getNumInputs())) 181 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 182 183 // 4.b. Producer output operand/map that is fused needs to be mapped to the 184 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 185 assert(producer->getNumResults() == 1 && "expected single result producer"); 186 if (producer.isInitTensor(producer.getOutputOperand(0))) { 187 BlockArgument bbArg = producerBlock.getArguments() 188 .drop_front(producer.getNumInputs()) 189 // TODO: bbArg index of 190 .front(); 191 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 192 } 193 // 5. Remaining consumer's input operands (drop past index `consumerIdx`). 194 for (BlockArgument bbArg : 195 consumerBlock.getArguments() 196 .take_front(consumer.getNumInputs()) 197 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 198 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 199 // 6. All of consumer's output operands. 200 for (BlockArgument bbArg : 201 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 202 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 203 // 7. All of producer's output operands except the one fused. 204 // TODO: allow fusion of multi-result producers. 205 assert(producer->getNumResults() == 1 && "expected single result producer"); 206 207 // 8. Clone all producer operations except for the yield and index operations 208 // to the fused operation. 209 for (auto &op : producerBlock.without_terminator()) { 210 if (!isa<IndexOp>(op)) 211 rewriter.clone(op, mapper); 212 } 213 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. 
Just 214 // forward the yield operand. 215 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 216 // TODO: allow fusion of multi-result producers. 217 assert(producer->getNumResults() == 1 && "expected single result producer"); 218 unsigned producerResultNumber = 0; 219 Value replacement = 220 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 221 // Sanity checks, if replacement is not already in the mapper then it must be 222 // produced outside. 223 if (replacement == yieldOp.getOperand(producerResultNumber)) { 224 if (auto bb = replacement.dyn_cast<BlockArgument>()) 225 assert(bb.getOwner() != &producerBlock && 226 "yielded block argument must have been mapped"); 227 else 228 assert(!producer->isAncestor(replacement.getDefiningOp()) && 229 "yielded value must have been mapped"); 230 } 231 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 232 replacement); 233 // 10. Clone operations from the consumer to the fused op. 234 for (auto &op : consumerBlock.getOperations()) 235 rewriter.clone(op, mapper); 236 237 // Sanity checks. 238 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 239 "Ill-formed GenericOp region"); 240 } 241 242 static Optional<SmallVector<Value>> 243 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 244 const ControlElementwiseOpsFusionFn &controlFn, 245 PatternRewriter &rewriter) { 246 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 247 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 248 !controlFn(producer->getResult(0), *consumerOpOperand)) 249 return llvm::None; 250 251 // TODO: allow fusing the producer of an output operand. 252 assert(consumer.isInputTensor(consumerOpOperand) && 253 "expected producer of input operand"); 254 255 // Compute the fused operands list and indexing maps. 
256 SmallVector<Value> fusedOperands; 257 SmallVector<AffineMap> fusedIndexMaps; 258 fusedOperands.reserve(producer->getNumOperands() + 259 consumer->getNumOperands()); 260 fusedIndexMaps.reserve(producer->getNumOperands() + 261 consumer->getNumOperands()); 262 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 263 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 264 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 265 SmallVector<OpOperand *>::iterator it = 266 llvm::find(consumerInputs, consumerOpOperand); 267 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 268 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 269 fusedOperands.push_back(opOperand->get()); 270 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 271 } 272 // 4. Splice in producer's input operands/maps. 273 assert(producer->getNumResults() == 1 && "expected single result producer"); 274 AffineMap producerResultIndexMap = 275 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 276 for (OpOperand *opOperand : producer.getInputOperands()) { 277 fusedOperands.push_back(opOperand->get()); 278 // Compute indexing maps for the producer args in the fused operation. 279 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 280 opOperand, producerResultIndexMap, 281 consumer.getTiedIndexingMap(consumerOpOperand)); 282 fusedIndexMaps.push_back(map); 283 } 284 // 4.b. Producer output operand/map that is fused needs to be passed if it is 285 // an "initTensor" (i.e. its value is actually read). 286 assert(producer->getNumResults() == 1 && "expected single result producer"); 287 if (producer.isInitTensor(producer.getOutputOperand(0))) { 288 fusedOperands.push_back(producer.getOutputOperand(0)->get()); 289 // Compute indexing maps for the producer args in the fused operation. 
290 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 291 producer.getOutputOperand(0), producerResultIndexMap, 292 consumer.getTiedIndexingMap(consumerOpOperand)); 293 fusedIndexMaps.push_back(map); 294 } 295 // 5. Remaining consumer's input operands/maps (drop past index 296 // `consumerIdx`). 297 for (OpOperand *opOperand : 298 llvm::make_range(std::next(it), consumerInputs.end())) { 299 fusedOperands.push_back(opOperand->get()); 300 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 301 } 302 // 6. All of consumer's output operands (skip operands: added by the builder). 303 for (OpOperand *opOperand : consumer.getOutputOperands()) 304 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 305 // 7. All of producer's output operands/maps except the one fused. 306 // TODO: allow fusion of multi-result producers. 307 assert(producer->getNumResults() == 1 && "expected single result producer"); 308 309 // Generate the fused op. 310 SmallVector<Value> consumerOutputs = consumer.getOutputOperands(); 311 auto fusedOp = rewriter.create<GenericOp>( 312 consumer.getLoc(), consumer->getResultTypes(), 313 /*inputs=*/fusedOperands, 314 // TODO: handle outputs. 315 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 316 consumer.iterator_types(), 317 /*doc=*/nullptr, 318 /*library_call=*/nullptr); 319 320 // Construct an AffineMap from consumer loops to producer loops. 
321 // consumer loop -> tensor index 322 AffineMap consumerResultIndexMap = 323 consumer.getTiedIndexingMap(consumerOpOperand); 324 // tensor index -> producer loop 325 AffineMap invProducerResultIndexMap = 326 inversePermutation(producerResultIndexMap); 327 assert(invProducerResultIndexMap && 328 "expected producer result indexig map to be invertible"); 329 // consumer loop -> producer loop 330 AffineMap consumerToProducerLoopsMap = 331 invProducerResultIndexMap.compose(consumerResultIndexMap); 332 333 generateFusedElementwiseOpRegion(rewriter, fusedOp, 334 consumerToProducerLoopsMap, 335 consumerOpOperand, consumer.getNumLoops()); 336 return SmallVector<Value>(fusedOp->getResults()); 337 } 338 339 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` 340 /// provided, given the shape of the source tensor that corresponds to the 341 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions 342 /// are "row-major" ordered logically. 343 /// 344 /// For example: 345 /// 346 /// %0 = op ... : tensor<?x?x4x5xf32> 347 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` 348 /// 349 /// and reshape: 350 /// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] : 351 /// tensor<?x?x4x5xf32> into tensor<?x?xf32> 352 /// 353 /// would be rewritten into: 354 /// %0 = op ... : tensor<?x?x4x5xf32> 355 /// with output index_map 356 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` 357 template <typename TensorReshapeOp> 358 static AffineMap linearizeCollapsedDims(AffineMap sourceMap, 359 TensorReshapeOp reshapeOp) { 360 constexpr bool isExpanding = 361 std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value; 362 ArrayRef<int64_t> sourceShape = 363 (isExpanding ? 
reshapeOp.getResultType().getShape() 364 : reshapeOp.getSrcType().getShape()); 365 SmallVector<AffineExpr> resultExprs; 366 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults(); 367 MLIRContext *context = sourceMap.getContext(); 368 369 // Compute the result exprs based on the reassociation maps. 370 for (auto &indices : reshapeOp.getReassociationIndices()) { 371 // Assume that they are in-order and contiguous (already checked in 372 // verifier). 373 assert(!indices.empty()); 374 SmallVector<int64_t> sizes; 375 SmallVector<AffineExpr> dimExprs; 376 for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()), 377 sourceExprs.slice(indices[0], indices.size()))) { 378 if (std::get<0>(en) == 1) 379 continue; 380 sizes.push_back(std::get<0>(en)); 381 dimExprs.push_back(std::get<1>(en)); 382 } 383 AffineExpr linearizedExpr = 384 makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); 385 resultExprs.push_back(linearizedExpr); 386 } 387 // The new affine map cannot drop unused dimension but some new symbols may 388 // have been added. Create a map with at least as many dimensions/symbols as 389 // the original affine map. 390 int64_t maxDim = -1; 391 int64_t maxSym = -1; 392 getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym); 393 unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims()); 394 unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols()); 395 return AffineMap::get(numDims, numSyms, resultExprs, context); 396 } 397 398 // tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a 399 // producer). Fusing when operand has higher rank will require use of mods and 400 // divs in the indexing maps of the fused op which would make it non-invertible. 
401 static bool isTensorReshapeOpFoldableByLinearization( 402 tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { 403 if (!asProducer) 404 return false; 405 return useIndexMap.isPermutation(); 406 } 407 408 // tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a 409 // consumer). 410 static bool 411 isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp, 412 AffineMap useIndexMap, 413 bool asProducer) { 414 if (asProducer) 415 return false; 416 return useIndexMap.isPermutation(); 417 } 418 419 /// Check if the reshape operation is only expansion into/collapsing of 420 /// unit-dimension. 421 template <typename TensorReshapeOp> 422 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) { 423 constexpr bool isExpanding = 424 std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value; 425 ArrayRef<int64_t> expandedShape = 426 (isExpanding ? reshapeOp.getResultType().getShape() 427 : reshapeOp.getSrcType().getShape()); 428 for (auto &indices : reshapeOp.getReassociationIndices()) { 429 unsigned numUnitDims = 0; 430 for (int64_t position : indices) 431 if (expandedShape[position] == 1) 432 numUnitDims++; 433 if (numUnitDims != indices.size() - 1) 434 return false; 435 } 436 return true; 437 } 438 439 /// Conditions for folding a generic operation with a reshape op by expanding 440 /// the iteration space dimensionality for tensor operations. These are 441 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 442 /// following fusion pattern. 
///
/// Consider
///
///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor_expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// reassociation map helps compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
/// %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the iteration space of the modified op is now 6D, reshapes can be
/// introduced on the operands to make them consistent
///
/// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
/// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with its producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Total number of loop dimensions in the expanded operation.
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  // The static loop ranges give the extent of each original loop dimension;
  // they are needed for dimensions not touched by the reshape.
  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimension in the expanded op that correspond to each
  // dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the expanded
/// op. That requires the shape of the expanded dimensions to be static (at
/// least all but the most significant). For now check that these are all
/// statically sized. Note that this could be extended to handle dynamic case,
/// but the implementation below uses `affine.apply` which seems to have issues
/// when the shapes are not static.
599 LogicalResult isGenericOpExpandable(GenericOp genericOp, 600 const ExpansionInfo &expansionInfo, 601 PatternRewriter &rewriter) { 602 if (!genericOp.hasIndexSemantics()) 603 return success(); 604 for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) { 605 ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i); 606 if (expandedShape.size() == 1) 607 continue; 608 for (int64_t shape : expandedShape.drop_front()) { 609 if (ShapedType::isDynamic(shape)) { 610 return rewriter.notifyMatchFailure( 611 genericOp, "cannot expand due to index semantics and dynamic dims"); 612 } 613 } 614 } 615 return success(); 616 } 617 618 /// Return the indexing map to use in the expanded op for a given the 619 /// `indexingMap` of the original operation. 620 static AffineMap 621 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 622 const ExpansionInfo &expansionInfo) { 623 SmallVector<AffineExpr> newExprs; 624 for (AffineExpr expr : indexingMap.getResults()) { 625 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 626 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 627 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 628 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 629 })); 630 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 631 } 632 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 633 indexingMap.getNumSymbols(), newExprs, 634 builder.getContext()); 635 } 636 637 /// Return the type of the operand/result to use in the expanded op given the 638 /// type in the original op. 
639 static RankedTensorType getExpandedType(RankedTensorType originalType, 640 AffineMap indexingMap, 641 const ExpansionInfo &expansionInfo) { 642 SmallVector<int64_t> expandedShape; 643 for (AffineExpr expr : indexingMap.getResults()) { 644 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 645 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 646 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 647 } 648 return RankedTensorType::get(expandedShape, originalType.getElementType()); 649 } 650 651 /// Returns the reassociation maps to use in the `tensor.expand_shape` 652 /// operation to convert the operands of the original operation to operands of 653 /// the expanded operation. The same method is used to compute the 654 /// `tensor.collapse_shape` used to collapse the result of the expanded 655 /// op to get the value that can replace all uses of the results of the original 656 /// op. 657 static SmallVector<ReassociationIndices> 658 getReassociationForExpansion(AffineMap indexingMap, 659 const ExpansionInfo &expansionInfo) { 660 SmallVector<ReassociationIndices> reassociation; 661 unsigned numReshapeDims = 0; 662 for (AffineExpr expr : indexingMap.getResults()) { 663 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 664 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 665 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 666 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 667 reassociation.emplace_back(std::move(indices)); 668 numReshapeDims += numExpandedDims; 669 } 670 return reassociation; 671 } 672 673 /// Update the body of an expanded linalg operation having index semantics. The 674 /// indices of the original operation need to be recovered by linearizing the 675 /// indices of the correspoding dimensions of the expanded operation. For now it 676 /// is assumed that the shapes of the expanded operation needed for 677 /// linearization are static. 
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    // Extents of the trailing (least significant) expanded dimensions; the
    // most significant one does not appear as a multiplier.
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    // Accumulate newIndex = index_i + newIndex * shape_i for each trailing
    // expanded dimension, most significant first.
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
719 static Optional<SmallVector<Value>> 720 fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, 721 OpOperand *fusableOpOperand, 722 PatternRewriter &rewriter) { 723 assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && 724 "preconditions for fuse operation failed"); 725 // Check if reshape is expanding or collapsing. 726 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp); 727 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp); 728 bool isExpanding = (expandingReshapeOp != nullptr); 729 RankedTensorType expandedType = isExpanding 730 ? expandingReshapeOp.getResultType() 731 : collapsingReshapeOp.getSrcType(); 732 733 ExpansionInfo expansionInfo; 734 if (failed(expansionInfo.compute( 735 genericOp, fusableOpOperand, 736 isExpanding ? expandingReshapeOp.getReassociationMaps() 737 : collapsingReshapeOp.getReassociationMaps(), 738 expandedType.getShape(), rewriter))) 739 return llvm::None; 740 741 if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter))) 742 return llvm::None; 743 744 SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>( 745 llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) { 746 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); 747 })); 748 749 SmallVector<Value> expandedOpOperands; 750 expandedOpOperands.reserve(genericOp.getNumInputs()); 751 for (OpOperand *opOperand : genericOp.getInputOperands()) { 752 if (opOperand == fusableOpOperand) { 753 expandedOpOperands.push_back(isExpanding ? 
expandingReshapeOp.src() 754 : collapsingReshapeOp.src()); 755 continue; 756 } 757 if (genericOp.isInputTensor(opOperand)) { 758 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 759 RankedTensorType expandedOperandType = 760 getExpandedType(opOperand->get().getType().cast<RankedTensorType>(), 761 indexingMap, expansionInfo); 762 if (expandedOperandType != opOperand->get().getType()) { 763 // Reshape the operand to get the right type. 764 SmallVector<ReassociationIndices> reassociation = 765 getReassociationForExpansion(indexingMap, expansionInfo); 766 expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>( 767 genericOp.getLoc(), expandedOperandType, opOperand->get(), 768 reassociation)); 769 continue; 770 } 771 } 772 expandedOpOperands.push_back(opOperand->get()); 773 } 774 775 Location loc = genericOp.getLoc(); 776 SmallVector<Value> outputs; 777 for (OpOperand *opOperand : genericOp.getOutputOperands()) { 778 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 779 RankedTensorType expandedOutputType = 780 getExpandedType(opOperand->get().getType().cast<RankedTensorType>(), 781 indexingMap, expansionInfo); 782 if (expandedOutputType != opOperand->get().getType()) { 783 SmallVector<ReassociationIndices> reassociation = 784 getReassociationForExpansion(indexingMap, expansionInfo); 785 outputs.push_back(rewriter.create<tensor::ExpandShapeOp>( 786 genericOp.getLoc(), expandedOutputType, opOperand->get(), 787 reassociation)); 788 } 789 } 790 791 // The iterator types of the expanded op are all parallel. 
792 SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(), 793 getParallelIteratorTypeName()); 794 795 TypeRange resultTypes = ValueRange(outputs).getTypes(); 796 auto fusedOp = 797 rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes, 798 /*inputs=*/expandedOpOperands, outputs, 799 expandedOpIndexingMaps, iteratorTypes); 800 Region &fusedRegion = fusedOp->getRegion(0); 801 Region &originalRegion = genericOp->getRegion(0); 802 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin()); 803 804 // Update the index accesses after the expansion. 805 updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo); 806 807 // Reshape the result values to their original shape if this is a collapsing 808 // reshape folded into its consumer. 809 SmallVector<Value> resultVals; 810 for (OpResult opResult : genericOp->getOpResults()) { 811 int64_t resultNumber = opResult.getResultNumber(); 812 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) { 813 SmallVector<ReassociationIndices> reassociation = 814 getReassociationForExpansion( 815 genericOp.getTiedIndexingMap( 816 genericOp.getOutputOperand(resultNumber)), 817 expansionInfo); 818 resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>( 819 genericOp.getLoc(), opResult.getType(), 820 fusedOp->getResult(resultNumber), reassociation)); 821 } else { 822 resultVals.push_back(fusedOp->getResult(resultNumber)); 823 } 824 } 825 // Assuming a single result. 826 return resultVals; 827 } 828 829 namespace { 830 831 /// Pattern to fold tensor_expand_shape op with its consumer by using the source 832 /// of the reshape op as the operand in the consumer (instead of the result of 833 /// the tensor_collapse_shape). The corresponding index map in the consumer 834 /// needs to be modified to linearize the folded dimension. 
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    // Fold the first input operand that is produced by a foldable reshape.
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list: same as the original, with the
      // reshaped operand replaced by the reshape's source.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumers that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Update the op in place: swap in the new operands and indexing maps.
      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

/// Convert each reassociation map (whose results are all AffineDimExprs) into
/// the corresponding list of dimension positions.
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
/// This can only be done if there is no broadcast or permutation
/// within the dimensions we need to merge.
///
/// For example,
///
///  %0 = tensor.expand_shape %A [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///    ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
/// into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensor.
    if (!genericOp.hasTensorSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and save the dimensions
    // merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociate maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the associated reverse map
    // (expanded dimension -> index of its reassociation group).
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg and that we
    // don't need to create new reshapes operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (auto en : llvm::enumerate(inputOperands)) {
      // Operands that were NOT rewritten to a reshape source must not touch
      // any merged (multi-dim) group.
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip dimension merged except for the last of the group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that the types match the uses.
    SmallVector<Value> newResults;
    for (auto result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer =*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are same as
    // those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    auto invMap = inversePermutation(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    // Reshape the producer's output operand so it matches the fused op's
    // output shape.
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};

/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
1177 struct FoldReshapeWithGenericOpByExpansion 1178 : public OpRewritePattern<tensor::ExpandShapeOp> { 1179 1180 FoldReshapeWithGenericOpByExpansion( 1181 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1182 PatternBenefit benefit = 1) 1183 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), 1184 controlFoldingReshapes(foldReshapes) {} 1185 1186 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, 1187 PatternRewriter &rewriter) const override { 1188 // Fold only if all constraints of fusing with reshape by expansion are met. 1189 GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>(); 1190 if (!producer || producer.getNumOutputs() != 1 || 1191 !isFusableWithReshapeByDimExpansion(producer, 1192 producer.getOutputOperand(0)) || 1193 !controlFoldingReshapes(producer->getResult(0), 1194 reshapeOp->getOpOperand(0))) 1195 return failure(); 1196 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 1197 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 1198 if (!replacementValues) 1199 return failure(); 1200 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 1201 return success(); 1202 } 1203 1204 private: 1205 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1206 }; 1207 1208 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1209 /// handle cases where the constant is not single-valued. 
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      // Matches a splat dense constant, an integer constant, or a float
      // constant; on success `constantAttr` holds the single scalar value.
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation, with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  // Tagged scalar: exactly one of the two optionals is expected to be set,
  // depending on whether the element type is integer or float.
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given we are generating constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, and then do affine apply on them, and then match back the
    // constant again.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    // The payload must not read the output operands (pure overwrite).
    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases but they have lifetime as the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate spaces for compute function inputs. Initial values do not matter
    // here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      // Delinearize the loop-space index.
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      // Permute loop-space indices into per-operand tensor indices.
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      // Re-linearize per-operand indices against each operand's shape.
      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponds to the input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    // Identity region computation: forward the single input element through
    // whichever of the APInt/APFloat slots is populated for this element type.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), llvm::None};
      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
    };
  }

  // NOTE(review): this member appears unused in the visible code, and the
  // FoldConstantBase CRTP base already declares its own `controlFn` --
  // likely dead; confirm against the full file before removing.
  ControlElementwiseOpsFusionFn controlFn;
};

} // namespace

/// Fuses the `producer` generic op into the consumer at `consumerOpOperand`,
/// returning the values that replace the consumer's results, or llvm::None if
/// fusion is not performed here.
static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  // Multi-result producers are not handled; bail out.
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

/// Control function that rejects reshape producers/consumers whose sole effect
/// is expanding or collapsing unit dimensions (those are handled elsewhere).
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  // Neither side is a unit-dim-only reshape: allow fusion.
  return true;
}

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors,
    // and fuse that producer into this op if permitted.
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    // No fusable operand found.
    return failure();
  }

private:
  // Predicate deciding which producer/consumer pairs may be fused.
  ControlElementwiseOpsFusionFn controlFn;
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    // With the pass option set, fold every reshape; otherwise skip the
    // unit-dimension-only reshapes (see skipUnitDimReshape above).
    ControlElementwiseOpsFusionFn allowFoldingFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          return true;
        };
    populateElementwiseOpsFusionPatterns(
        patterns,
        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));

    // Use TopDownTraversal for compile time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
1674 struct FoldReshapeOpsByLinearizationPass 1675 : public LinalgFoldReshapeOpsByLinearizationBase< 1676 FoldReshapeOpsByLinearizationPass> { 1677 void runOnOperation() override { 1678 Operation *op = getOperation(); 1679 RewritePatternSet patterns(op->getContext()); 1680 populateFoldReshapeOpsByLinearizationPatterns(patterns); 1681 if (allowFoldingUnitDimReshapes) { 1682 populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns); 1683 } 1684 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); 1685 } 1686 }; 1687 1688 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if 1689 /// the value of the `outs` operand is not used within the op. This is only 1690 /// implemented for `linalg.generic` operations for now, but should hold for all 1691 /// linalg structured ops. 1692 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> { 1693 using OpRewritePattern<GenericOp>::OpRewritePattern; 1694 1695 LogicalResult matchAndRewrite(GenericOp op, 1696 PatternRewriter &rewriter) const override { 1697 rewriter.startRootUpdate(op); 1698 bool modifiedOutput = false; 1699 Location loc = op.getLoc(); 1700 for (OpOperand *opOperand : op.getOutputOperands()) { 1701 if (!op.payloadUsesValueFromOperand(opOperand)) { 1702 Value operandVal = opOperand->get(); 1703 auto operandType = operandVal.getType().dyn_cast<RankedTensorType>(); 1704 if (!operandType) 1705 continue; 1706 1707 // If outs is already an `init_tensor` operation, nothing to do. 
1708 auto definingOp = operandVal.getDefiningOp<InitTensorOp>(); 1709 if (definingOp) 1710 continue; 1711 modifiedOutput = true; 1712 SmallVector<Value> dynamicDims; 1713 for (auto dim : llvm::enumerate(operandType.getShape())) { 1714 if (dim.value() != ShapedType::kDynamicSize) 1715 continue; 1716 dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>( 1717 loc, operandVal, dim.index())); 1718 } 1719 Value initTensor = rewriter.create<InitTensorOp>( 1720 loc, dynamicDims, operandType.getShape(), 1721 operandType.getElementType()); 1722 op->setOperand(opOperand->getOperandNumber(), initTensor); 1723 } 1724 } 1725 if (!modifiedOutput) { 1726 rewriter.cancelRootUpdate(op); 1727 return failure(); 1728 } 1729 rewriter.finalizeRootUpdate(op); 1730 return success(); 1731 } 1732 }; 1733 1734 } // namespace 1735 1736 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns( 1737 RewritePatternSet &patterns) { 1738 patterns 1739 .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>, 1740 FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>, 1741 FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>, 1742 FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>( 1743 patterns.getContext()); 1744 } 1745 1746 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( 1747 RewritePatternSet &patterns) { 1748 patterns 1749 .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>, 1750 FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>, 1751 FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>, 1752 FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>( 1753 patterns.getContext()); 1754 } 1755 1756 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( 1757 RewritePatternSet &patterns, 1758 ControlElementwiseOpsFusionFn controlFoldingReshapes) { 1759 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(), 1760 
controlFoldingReshapes); 1761 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), 1762 controlFoldingReshapes); 1763 } 1764 1765 void mlir::linalg::populateElementwiseOpsFusionPatterns( 1766 RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) { 1767 auto *context = patterns.getContext(); 1768 patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant, 1769 FoldConstantTranspose>(context, 1770 options.controlElementwiseOpsFusionFn); 1771 patterns.add<RemoveOutsDependency>(context); 1772 populateFoldReshapeOpsByExpansionPatterns(patterns, 1773 options.controlFoldingReshapesFn); 1774 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 1775 GenericOp::getCanonicalizationPatterns(patterns, context); 1776 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); 1777 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); 1778 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns( 1779 patterns); 1780 } 1781 1782 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) { 1783 auto *context = patterns.getContext(); 1784 patterns.add<PushExpandingReshape>(context); 1785 } 1786 1787 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() { 1788 return std::make_unique<LinalgElementwiseOpFusionPass>(); 1789 } 1790 1791 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() { 1792 return std::make_unique<FoldReshapeOpsByLinearizationPass>(); 1793 } 1794