1 //===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===/// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion on tensors operations pass. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "PassDetail.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Support/LLVM.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 /// Conditions for elementwise fusion of generic operations. 30 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, 31 OpOperand *consumerOpOperand) { 32 // Producer and consumer must have tensor semantics. 33 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 34 return false; 35 36 // Verify that 37 // - the producer has all "parallel" iterator type. 38 if (producer.getNumParallelLoops() != producer.getNumLoops()) 39 return false; 40 41 // Only allow fusing the producer of an input operand for now. 42 // TODO: allow fusing the producer of an output operand. 43 if (!consumer.isInputTensor(consumerOpOperand)) 44 return false; 45 46 // Get the consumer index map. The number of results of the consumer index 47 // map must match the number of loops of the producer. 
48 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); 49 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 50 return false; 51 52 // Currently support only operations with single result. 53 if (producer.getNumOutputs() != 1) 54 return false; 55 56 // Finally the index_map for the result must be invertible. For now just 57 // verify it is a permutation. 58 AffineMap producerResultIndexMap = 59 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 60 return producerResultIndexMap.isPermutation(); 61 } 62 63 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of 64 /// the `producer` to use in the fused operation given the indexing map of the 65 /// result of the producer in the consumer. 66 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 67 OpOperand *producerOpOperand, AffineMap producerResultIndexMap, 68 AffineMap fusedConsumerArgIndexMap) { 69 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 70 // from consumer loop -> consumer arg tensor index/producer result tensor 71 // index. The fused loop is same as the consumer loop. For each producer arg 72 // the indexing map to be computed is a map from consumer loop -> producer 73 // arg tensor index. 74 // producerResultIndexMap is a map from producer loop -> tensor index. 75 // Compute the inverse to get map from tensor index -> producer loop. 76 // The inverse is a map from producer result tensor index -> producer loop. 77 AffineMap invProducerResultIndexMap = 78 inversePermutation(producerResultIndexMap); 79 assert(invProducerResultIndexMap && 80 "expected producer result indexig map to be invertible"); 81 82 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner()); 83 // argMap is a map from producer loop -> producer arg tensor index. 
84 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); 85 86 // Compose argMap with invProducerResultIndexMap to get a map from 87 // producer result tensor index -> producer arg tensor index. 88 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 89 90 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 91 // consumer loop/ fused loop -> producer arg tensor index. 92 return t1.compose(fusedConsumerArgIndexMap); 93 } 94 95 /// Generate the region of the fused tensor operation. The region of the fused 96 /// op must be empty. 97 static void 98 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, 99 AffineMap consumerToProducerLoopsMap, 100 OpOperand *consumerOpOperand, 101 unsigned nloops) { 102 auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp()); 103 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 104 // Build the region of the fused op. 105 Block &producerBlock = producer->getRegion(0).front(); 106 Block &consumerBlock = consumer->getRegion(0).front(); 107 Block *fusedBlock = new Block(); 108 fusedOp.region().push_back(fusedBlock); 109 BlockAndValueMapping mapper; 110 OpBuilder::InsertionGuard guard(rewriter); 111 rewriter.setInsertionPointToStart(fusedBlock); 112 113 // 2. Add an index operation for every fused loop dimension and use the 114 // `consumerToProducerLoopsMap` to map the producer indices. 115 if (producer.hasIndexSemantics()) { 116 // Add an index operation for every fused loop dimension. 
117 unsigned numFusedOpLoops = 118 std::max(producer.getNumLoops(), consumer.getNumLoops()); 119 SmallVector<Value> fusedIndices; 120 fusedIndices.reserve(numFusedOpLoops); 121 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), 122 std::back_inserter(fusedIndices), [&](uint64_t dim) { 123 return rewriter.create<IndexOp>(producer.getLoc(), dim); 124 }); 125 for (IndexOp indexOp : 126 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { 127 Value newIndex = rewriter.create<mlir::AffineApplyOp>( 128 producer.getLoc(), 129 consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); 130 mapper.map(indexOp.getResult(), newIndex); 131 } 132 } 133 // TODO: allow fusing the producer of an output operand. 134 assert(consumer.isInputTensor(consumerOpOperand) && 135 "expected producer of input operand"); 136 // 3. Consumer input operands up to consumerIdx (exclusive). 137 for (BlockArgument bbArg : consumerBlock.getArguments().take_front( 138 consumerOpOperand->getOperandNumber())) // input assumption. 139 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 140 141 // Replacing consumerIdx requires getting the cloned, yielded, value from 142 // the (cloned) producer block. This happens in step 9. 143 144 // 4. Splice in producer's input operands. 145 for (BlockArgument bbArg : 146 producerBlock.getArguments().take_front(producer.getNumInputs())) 147 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 148 149 // 4.b. Producer output operand/map that is fused needs to be mapped to the 150 // producer bbArg if it is an "initTensor" (i.e. its value is actually read). 151 assert(producer->getNumResults() == 1 && "expected single result producer"); 152 if (producer.isInitTensor(producer.getOutputOperand(0))) { 153 BlockArgument bbArg = producerBlock.getArguments() 154 .drop_front(producer.getNumInputs()) 155 // TODO: bbArg index of 156 .front(); 157 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 158 } 159 // 5. 
Remaining consumer's input operands (drop past index `consumerIdx`). 160 for (BlockArgument bbArg : 161 consumerBlock.getArguments() 162 .take_front(consumer.getNumInputs()) 163 .drop_front(consumerOpOperand->getOperandNumber() + 1)) 164 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 165 // 6. All of consumer's output operands. 166 for (BlockArgument bbArg : 167 consumerBlock.getArguments().take_back(consumer.getNumOutputs())) 168 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); 169 // 7. All of producer's output operands except the one fused. 170 // TODO: allow fusion of multi-result producers. 171 assert(producer->getNumResults() == 1 && "expected single result producer"); 172 173 // 8. Clone all producer operations except for the yield and index operations 174 // to the fused operation. 175 for (auto &op : producerBlock.without_terminator()) { 176 if (!isa<IndexOp>(op)) 177 rewriter.clone(op, mapper); 178 } 179 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just 180 // forward the yield operand. 181 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator()); 182 // TODO: allow fusion of multi-result producers. 183 assert(producer->getNumResults() == 1 && "expected single result producer"); 184 unsigned producerResultNumber = 0; 185 Value replacement = 186 mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber)); 187 // Sanity checks, if replacement is not already in the mapper then it must be 188 // produced outside. 189 if (replacement == yieldOp.getOperand(producerResultNumber)) { 190 if (auto bb = replacement.dyn_cast<BlockArgument>()) 191 assert(bb.getOwner() != &producerBlock && 192 "yielded block argument must have been mapped"); 193 else 194 assert(!producer->isAncestor(replacement.getDefiningOp()) && 195 "yielded value must have been mapped"); 196 } 197 mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), 198 replacement); 199 // 10. 
Clone operations from the consumer to the fused op. 200 for (auto &op : consumerBlock.getOperations()) 201 rewriter.clone(op, mapper); 202 203 // Sanity checks. 204 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && 205 "Ill-formed GenericOp region"); 206 } 207 208 static Optional<SmallVector<Value>> 209 fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, 210 const ControlElementwiseOpsFusionFn &controlFn, 211 PatternRewriter &rewriter) { 212 auto consumer = cast<GenericOp>(consumerOpOperand->getOwner()); 213 if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || 214 !controlFn(producer->getResult(0), *consumerOpOperand)) 215 return llvm::None; 216 217 // TODO: allow fusing the producer of an output operand. 218 assert(consumer.isInputTensor(consumerOpOperand) && 219 "expected producer of input operand"); 220 221 // Compute the fused operands list and indexing maps. 222 SmallVector<Value> fusedOperands; 223 SmallVector<AffineMap> fusedIndexMaps; 224 fusedOperands.reserve(producer->getNumOperands() + 225 consumer->getNumOperands()); 226 fusedIndexMaps.reserve(producer->getNumOperands() + 227 consumer->getNumOperands()); 228 // In the following, numbering matches that of `generateFusedTensorOpRegion`. 229 // 3. Consumer input operands/maps up to consumerIdx (exclusive). 230 SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands(); 231 SmallVector<OpOperand *>::iterator it = 232 llvm::find(consumerInputs, consumerOpOperand); 233 assert(it != consumerInputs.end() && "expected to find the consumer operand"); 234 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { 235 fusedOperands.push_back(opOperand->get()); 236 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 237 } 238 // 4. Splice in producer's input operands/maps. 
239 assert(producer->getNumResults() == 1 && "expected single result producer"); 240 AffineMap producerResultIndexMap = 241 producer.getTiedIndexingMap(producer.getOutputOperand(0)); 242 for (OpOperand *opOperand : producer.getInputOperands()) { 243 fusedOperands.push_back(opOperand->get()); 244 // Compute indexing maps for the producer args in the fused operation. 245 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 246 opOperand, producerResultIndexMap, 247 consumer.getTiedIndexingMap(consumerOpOperand)); 248 fusedIndexMaps.push_back(map); 249 } 250 // 4.b. Producer output operand/map that is fused needs to be passed if it is 251 // an "initTensor" (i.e. its value is actually read). 252 assert(producer->getNumResults() == 1 && "expected single result producer"); 253 if (producer.isInitTensor(producer.getOutputOperand(0))) { 254 fusedOperands.push_back(producer.getOutputOperand(0)->get()); 255 // Compute indexing maps for the producer args in the fused operation. 256 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( 257 producer.getOutputOperand(0), producerResultIndexMap, 258 consumer.getTiedIndexingMap(consumerOpOperand)); 259 fusedIndexMaps.push_back(map); 260 } 261 // 5. Remaining consumer's input operands/maps (drop past index 262 // `consumerIdx`). 263 for (OpOperand *opOperand : 264 llvm::make_range(std::next(it), consumerInputs.end())) { 265 fusedOperands.push_back(opOperand->get()); 266 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 267 } 268 // 6. All of consumer's output operands (skip operands: added by the builder). 269 for (OpOperand *opOperand : consumer.getOutputOperands()) 270 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); 271 // 7. All of producer's output operands/maps except the one fused. 272 // TODO: allow fusion of multi-result producers. 273 assert(producer->getNumResults() == 1 && "expected single result producer"); 274 275 // Generate the fused op. 
276 SmallVector<Value> consumerOutputs = consumer.getOutputOperands(); 277 auto fusedOp = rewriter.create<GenericOp>( 278 consumer.getLoc(), consumer->getResultTypes(), 279 /*inputs=*/fusedOperands, 280 // TODO: handle outputs. 281 consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 282 consumer.iterator_types(), 283 /*doc=*/nullptr, 284 /*library_call=*/nullptr); 285 286 // Construct an AffineMap from consumer loops to producer loops. 287 // consumer loop -> tensor index 288 AffineMap consumerResultIndexMap = 289 consumer.getTiedIndexingMap(consumerOpOperand); 290 // tensor index -> producer loop 291 AffineMap invProducerResultIndexMap = 292 inversePermutation(producerResultIndexMap); 293 assert(invProducerResultIndexMap && 294 "expected producer result indexig map to be invertible"); 295 // consumer loop -> producer loop 296 AffineMap consumerToProducerLoopsMap = 297 invProducerResultIndexMap.compose(consumerResultIndexMap); 298 299 generateFusedElementwiseOpRegion(rewriter, fusedOp, 300 consumerToProducerLoopsMap, 301 consumerOpOperand, consumer.getNumLoops()); 302 return SmallVector<Value>(fusedOp->getResults()); 303 } 304 305 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` 306 /// provided, given the shape of the source tensor that corresponds to the 307 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions 308 /// are "row-major" ordered logically. 309 /// 310 /// For example: 311 /// 312 /// %0 = op ... : tensor<?x?x4x5xf32> 313 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` 314 /// 315 /// and reshape: 316 /// %1 = linalg.tensor_collapse_shape %0 [[0], [0, 1, 2]] : 317 /// tensor<?x?x4x5xf32> into tensor<?x?xf32> 318 /// 319 /// would be rewritten into: 320 /// %0 = op ... 
: tensor<?x?x4x5xf32> 321 /// with output index_map 322 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` 323 template <typename TensorReshapeOp> 324 static AffineMap linearizeCollapsedDims(AffineMap sourceMap, 325 TensorReshapeOp reshapeOp) { 326 constexpr bool isExpanding = 327 std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value; 328 ArrayRef<int64_t> sourceShape = 329 (isExpanding ? reshapeOp.getResultType().getShape() 330 : reshapeOp.getSrcType().getShape()); 331 SmallVector<AffineExpr> resultExprs; 332 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults(); 333 MLIRContext *context = sourceMap.getContext(); 334 335 // Compute the result exprs based on the reassociation maps. 336 for (auto &indices : reshapeOp.getReassociationIndices()) { 337 // Assume that they are in-order and contiguous (already checked in 338 // verifier). 339 assert(!indices.empty()); 340 SmallVector<int64_t> sizes; 341 SmallVector<AffineExpr> dimExprs; 342 for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()), 343 sourceExprs.slice(indices[0], indices.size()))) { 344 if (std::get<0>(en) == 1) 345 continue; 346 sizes.push_back(std::get<0>(en)); 347 dimExprs.push_back(std::get<1>(en)); 348 } 349 AffineExpr linearizedExpr = 350 makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); 351 resultExprs.push_back(linearizedExpr); 352 } 353 return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(), 354 resultExprs, context); 355 } 356 357 // TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a 358 // producer). Fusing when operand has higher rank will require use of mods and 359 // divs in the indexing maps of the fused op which would make it non-invertible. 
360 static bool isTensorReshapeOpFoldableByLinearization( 361 TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { 362 if (!asProducer) 363 return false; 364 return useIndexMap.isPermutation(); 365 } 366 367 // TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a 368 // consumer). 369 static bool isTensorReshapeOpFoldableByLinearization( 370 TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) { 371 if (asProducer) 372 return false; 373 return useIndexMap.isPermutation(); 374 } 375 376 /// Check if the reshape operation is only expansion into/collapsing of 377 /// unit-dimension. 378 template <typename TensorReshapeOp> 379 static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) { 380 constexpr bool isExpanding = 381 std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value; 382 ArrayRef<int64_t> expandedShape = 383 (isExpanding ? reshapeOp.getResultType().getShape() 384 : reshapeOp.getSrcType().getShape()); 385 for (auto &indices : reshapeOp.getReassociationIndices()) { 386 unsigned numUnitDims = 0; 387 for (int64_t position : indices) 388 if (expandedShape[position] == 1) 389 numUnitDims++; 390 if (numUnitDims != indices.size() - 1) 391 return false; 392 } 393 return true; 394 } 395 396 /// Conditions for folding a generic operation with a reshape op by expanding 397 /// the iteration space dimensionality for tensor operations. These are 398 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the 399 /// following fusion pattern. 
400 /// 401 /// Consider 402 /// 403 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>) 404 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 405 /// affine_map<(d0, d1, d2) -> (d1, d2)>, 406 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>] 407 /// %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]] 408 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 409 /// 410 /// The reshape can be folded into the `genericOp` if its loop dimensionality 411 /// is increased to match the result (operand) of the tensor_expand_shape. 412 /// The indexing_map of the fused tensor in the `genericOp` and the 413 /// reassociation map helps compute the indexing maps of the modified op. 414 /// For the above example, based on the reassociation map it 415 /// can be concluded that 416 /// 417 /// - The loop used to access the first dimension of the fused tensor is split 418 /// into two. 419 /// - The loop used to access the second dimension of the fused tensor is kept 420 /// as is. 421 /// - The loop used to access the third dimension of the fused tensor is split 422 /// into three. 423 /// 424 /// i.e. 
(e0, e1, e2, e3, e4) is the domain of the indexing map of the modified 425 /// op, then 426 /// 427 /// d0 -> e0, e1 428 /// d1 -> e2, e3, e4 429 /// d2 -> e5 430 /// 431 /// substituting this, the generic op can be rewritten as 432 /// 433 /// %d = linalg.generic ins(%0, %1 : ) 434 /// indexing_maps = 435 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>, 436 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>, 437 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>] 438 /// 439 /// Since operands to the linalg generic are now 5D, reshapes can be introduced 440 /// to make it consistent 441 /// 442 /// %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]] 443 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32> 444 /// %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]] 445 /// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32> 446 /// 447 /// The added reshapes are again expanding patterns, so they will get fused 448 /// with its producers if possible. 449 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, 450 OpOperand *fusableOpOperand) { 451 // Is fusable only if: 452 // - All the indexing maps for operands and results are projected 453 // permutations. 454 // - The fused tensor is not a scalar. 455 // - All the loops are parallel loops. 456 return genericOp.hasTensorSemantics() && 457 llvm::all_of(genericOp.indexing_maps().getValue(), 458 [](Attribute attr) { 459 return attr.cast<AffineMapAttr>() 460 .getValue() 461 .isProjectedPermutation(); 462 }) && 463 genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && 464 llvm::all_of(genericOp.iterator_types(), [](Attribute attr) { 465 return attr.cast<StringAttr>().getValue() == 466 getParallelIteratorTypeName(); 467 }); 468 } 469 470 namespace { 471 /// Information needed to expand a generic operation to fold the reshape with 472 /// it. 
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        PatternRewriter &rewriter);
  // Number of loops in the original (unexpanded) operation.
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  // Number of loops in the expanded operation.
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  // Expanded-op loop indices that original loop `i` expands into.
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  // Extents of the expanded-op loops that original loop `i` expands into.
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Total number of loops in the expanded operation.
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  // Static loop extents of the original op are needed for loops that are not
  // touched by the expansion.
  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimension in the expanded op that correspond to each
  // dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  // Each result of the fused operand's indexing map is an AffineDimExpr (maps
  // are projected permutations); the reassociation group at the same result
  // position describes how that loop expands.
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
556 LogicalResult isGenericOpExpandable(GenericOp genericOp, 557 const ExpansionInfo &expansionInfo, 558 PatternRewriter &rewriter) { 559 if (!genericOp.hasIndexSemantics()) 560 return success(); 561 for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) { 562 ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i); 563 if (expandedShape.size() == 1) 564 continue; 565 for (int64_t shape : expandedShape.drop_front()) { 566 if (ShapedType::isDynamic(shape)) { 567 return rewriter.notifyMatchFailure( 568 genericOp, "cannot expand due to index semantics and dynamic dims"); 569 } 570 } 571 } 572 return success(); 573 } 574 575 /// Return the indexing map to use in the expanded op for a given the 576 /// `indexingMap` of the original operation. 577 static AffineMap 578 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 579 const ExpansionInfo &expansionInfo) { 580 SmallVector<AffineExpr> newExprs; 581 for (AffineExpr expr : indexingMap.getResults()) { 582 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 583 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 584 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 585 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 586 })); 587 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 588 } 589 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 590 indexingMap.getNumSymbols(), newExprs, 591 builder.getContext()); 592 } 593 594 /// Return the type of the operand/result to use in the expanded op given the 595 /// type in the original op. 
596 static RankedTensorType getExpandedType(RankedTensorType originalType, 597 AffineMap indexingMap, 598 const ExpansionInfo &expansionInfo) { 599 SmallVector<int64_t> expandedShape; 600 for (AffineExpr expr : indexingMap.getResults()) { 601 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 602 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 603 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 604 } 605 return RankedTensorType::get(expandedShape, originalType.getElementType()); 606 } 607 608 /// Returns the reassociation maps to use in the `linalg.tensor_expand_shape` 609 /// operation to convert the operands of the original operation to operands of 610 /// the expanded operation. The same method is used to compute the 611 /// `linalg.tensor_collapse_shape` used to collapse the result of the expanded 612 /// op to get the value that can replace all uses of the results of the original 613 /// op. 614 static SmallVector<ReassociationIndices> 615 getReassociationForExpansion(AffineMap indexingMap, 616 const ExpansionInfo &expansionInfo) { 617 SmallVector<ReassociationIndices> reassociation; 618 unsigned numReshapeDims = 0; 619 for (AffineExpr expr : indexingMap.getResults()) { 620 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 621 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 622 SmallVector<int64_t, 2> indices = llvm::to_vector<2>( 623 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 624 reassociation.emplace_back(std::move(indices)); 625 numReshapeDims += numExpandedDims; 626 } 627 return reassociation; 628 } 629 630 /// Update the body of an expanded linalg operation having index semantics. The 631 /// indices of the original operation need to be recovered by linearizing the 632 /// indices of the correspoding dimensions of the expanded operation. For now it 633 /// is assumed that the shapes of the expanded operation needed for 634 /// linearization are static. 
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion: the loop
    // was not split and kept its position.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    // Extents of the inner (all but the outermost) expanded dimensions; these
    // must be static (checked by `isGenericOpExpandable`).
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    // Fold left-to-right: newIndex = ((outer * s1 + i1) * s2 + i2) * ...
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
/// that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing. Exactly one of the two
  // dyn_casts succeeds.
  auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  // The higher-rank type: the result of an expand, the source of a collapse.
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();

  // Compute how each original loop maps to the loops of the expanded op.
  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), rewriter)))
    return llvm::None;

  // Index semantics require static inner extents for linearization.
  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Build the input operands of the expanded op: the fused operand is replaced
  // by the reshape's source; other tensor inputs are reshaped to their
  // expanded types when necessary.
  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumInputs());
  for (OpOperand *opOperand : genericOp.getInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                               : collapsingReshapeOp.src());
      continue;
    }
    if (genericOp.isInputTensor(opOperand)) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                          indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    // Scalar inputs and tensors whose type does not change pass through.
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
    RankedTensorType expandedOutputType =
        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                        indexingMap, expansionInfo);
    // NOTE(review): when `expandedOutputType` equals the original type, no
    // value is appended to `outputs` for this operand — confirm every output
    // type actually changes under the expansion (otherwise the fused op below
    // would be built with fewer outputs than the original has results).
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      outputs.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  // Clone the original body into the new op; block arguments line up since
  // only types, not operand order, changed.
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<TensorCollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fold tensor_expand_shape op with its consumer by using the source
/// of the reshape op as the operand in the consumer (instead of the result of
/// the tensor_expand_shape). The corresponding index map in the consumer
/// needs to be modified to linearize the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]]
///   tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///   ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///   -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///   ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///   -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    // Scan the inputs for the first reshape producer that can be linearized
    // into the generic op; the rewrite handles one operand per invocation.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list: same operands, with the reshape's
      // result swapped for its source.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Update the generic op in place: new operands and new indexing maps.
      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

/// Convert reassociation maps (each result is an AffineDimExpr) into the
/// equivalent lists of dimension positions.
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops.
 This can only be done if there is no broadcast or permutation
/// within the dimensions we need to merge.
///
/// For example,
///
///  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///    ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
/// into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]]
///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensor.
    if (!genericOp.hasTensorSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    TensorExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and figure out the dimensions
    // that get merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<TensorExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociate maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the reassociated reverse map
    // (expanded dim -> collapsed group index).
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg and that we
    // don't need to create new reshapes operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (auto en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip dimension merged except for the last of the group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<TensorCollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that the types match the uses.
    SmallVector<Value> newResults;
    for (auto result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      TensorCollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  // Client-provided policy that decides whether a given reshape should fold.
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer =*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are same as
    // those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    auto invMap = inversePermutation(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    // Reshape the producer's init operand so it matches the fused result.
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};

/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
1134 struct FoldReshapeWithGenericOpByExpansion 1135 : public OpRewritePattern<TensorExpandShapeOp> { 1136 1137 FoldReshapeWithGenericOpByExpansion( 1138 MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, 1139 PatternBenefit benefit = 1) 1140 : OpRewritePattern<TensorExpandShapeOp>(context, benefit), 1141 controlFoldingReshapes(foldReshapes) {} 1142 1143 LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp, 1144 PatternRewriter &rewriter) const override { 1145 // Fold only if all constraints of fusing with reshape by expansion are met. 1146 GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>(); 1147 if (!producer || producer.getNumOutputs() != 1 || 1148 !isFusableWithReshapeByDimExpansion(producer, 1149 producer.getOutputOperand(0)) || 1150 !controlFoldingReshapes(producer->getResult(0), 1151 reshapeOp->getOpOperand(0))) 1152 return failure(); 1153 Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion( 1154 producer, reshapeOp, producer.getOutputOperand(0), rewriter); 1155 if (!replacementValues) 1156 return failure(); 1157 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 1158 return success(); 1159 } 1160 1161 private: 1162 ControlElementwiseOpsFusionFn controlFoldingReshapes; 1163 }; 1164 1165 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1166 /// handle cases where the constant is not single-valued. 
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      // Set by the matcher lambda below to the single scalar attribute value.
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          // Splat tensor constant with an int/float element type.
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<Attribute>();
            return true;
          }
        }
        {
          // Scalar integer constant.
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          // Scalar float constant.
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operations with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  // Client-provided policy that decides whether a given fold is allowed.
  ControlElementwiseOpsFusionFn controlFn;
};

/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  // Tagged scalar: exactly one of the two optionals is expected to be set.
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  // Per-iteration input elements; only one of the two vectors is populated
  // depending on the element type.
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given we are generating constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, and then do affine apply on them, and then match back the
    // constant again.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases but they have lifetime as the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate spaces for compute function inputs. Initial values do not matter
    // here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      // Delinearize into per-loop indices (innermost loop varies fastest).
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      // Apply each permutation map to pick the source/destination indices.
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      // Re-linearize with each tensor's own (static) shape.
      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  // Client-provided policy that decides whether a given fold is allowed.
  ControlElementwiseOpsFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
1541 return [](const APIntOrFloatArray &inputs) { 1542 if (inputs.apFloats.empty()) 1543 return APIntOrFloat{inputs.apInts.front(), llvm::None}; 1544 return APIntOrFloat{llvm::None, inputs.apFloats.front()}; 1545 }; 1546 } 1547 1548 ControlElementwiseOpsFusionFn controlFn; 1549 }; 1550 1551 } // namespace 1552 1553 static Optional<SmallVector<Value>> 1554 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, 1555 GenericOp producer, 1556 const ControlElementwiseOpsFusionFn &controlFn) { 1557 if (producer->getNumResults() != 1) 1558 return llvm::None; 1559 1560 return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn, 1561 rewriter); 1562 } 1563 1564 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, 1565 OpOperand &consumer) { 1566 if (auto producerCollapseOp = 1567 dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner())) { 1568 return !isUnitDimExpansionOnly(producerCollapseOp); 1569 } 1570 if (auto consumerExpandOp = 1571 dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) { 1572 return !isUnitDimExpansionOnly(consumerExpandOp); 1573 } 1574 return true; 1575 } 1576 1577 namespace { 1578 /// Patterns to fuse a generic op, with the producer of its operands. 1579 class FuseElementwiseOps : public OpRewritePattern<GenericOp> { 1580 public: 1581 FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun, 1582 PatternBenefit benefit = 1) 1583 : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {} 1584 1585 LogicalResult matchAndRewrite(GenericOp genericOp, 1586 PatternRewriter &rewriter) const override { 1587 // Find the first operand that is defined by another generic op on tensors. 
1588 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { 1589 auto producer = 1590 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp()); 1591 if (!producer || !producer.hasTensorSemantics()) 1592 continue; 1593 Optional<SmallVector<Value>> fusedOpResults = 1594 fuseElementwiseOps(rewriter, opOperand, producer, controlFn); 1595 if (fusedOpResults) { 1596 rewriter.replaceOp(genericOp, *fusedOpResults); 1597 return success(); 1598 } 1599 } 1600 return failure(); 1601 } 1602 1603 private: 1604 ControlElementwiseOpsFusionFn controlFn; 1605 }; 1606 1607 /// Pass that fuses generic ops on tensors. Used only for testing. 1608 struct LinalgElementwiseOpFusionPass 1609 : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> { 1610 void runOnOperation() override { 1611 Operation *op = getOperation(); 1612 RewritePatternSet patterns(op->getContext()); 1613 ControlElementwiseOpsFusionFn allowFoldingFn = 1614 [](const OpResult &producer, const OpOperand &consumer) { 1615 return true; 1616 }; 1617 populateElementwiseOpsFusionPatterns( 1618 patterns, 1619 LinalgElementwiseFusionOptions().setControlFoldingReshapes( 1620 allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape)); 1621 1622 // Use TopDownTraversal for compile time reasons 1623 GreedyRewriteConfig grc; 1624 grc.useTopDownTraversal = true; 1625 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), 1626 grc); 1627 } 1628 }; 1629 1630 /// Pass to test folding of reshape ops with generic ops by linearization. 
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    // Optionally also linearize reshapes that only touch unit dimensions,
    // controlled by the pass option.
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // The op is updated in place; bracket all operand mutations in a root
    // update so the driver re-processes the op (or we cancel if nothing
    // changed).
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      // Only outputs whose value is ignored by the payload can be replaced.
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        // Recover sizes for all dynamic dimensions from the current operand
        // so the replacement init_tensor has an identical shape.
        SmallVector<Value> dynamicDims;
        for (auto dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    // Cancel the root update if no operand was replaced so the driver does not
    // treat the op as modified.
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

} // namespace

// Registers producer- and consumer-side reshape linearization patterns
// (unit-dim reshapes excluded; see the `true` variants below for those).
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<false, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorExpandShapeOp>>(
          patterns.getContext());
}

// Same as above, but for reshapes that only expand/collapse unit dimensions.
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorExpandShapeOp>>(
          patterns.getContext());
}

// Registers patterns that fold reshapes with generic ops by expanding the
// generic op's iteration space; `controlFoldingReshapes` vetoes individual
// fusions.
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

// Registers the full elementwise-fusion pattern set: op-op fusion, constant
// folding, outs-dependency removal, reshape folding, and the relevant
// canonicalization patterns.
void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
               FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
      patterns);
}

// Registers the pattern that pushes expanding reshapes down through generic
// ops.
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}