1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion on tensors operations pass. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "PassDetail.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 28 /// Implementation of fusion of generic ops and indexed_generic ops. 29 static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer, 30 unsigned consumerIdx) { 31 // Producer and consumer must have tensor semantics. 32 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 33 return false; 34 35 // Verify that 36 // - the producer has all "parallel" iterator type. 37 if (producer.getNumParallelLoops() != producer.getNumLoops()) 38 return false; 39 40 // Get the consumer index map. The number of results of the consumer index 41 // map must match the number of loops of the producer. 42 AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); 43 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 44 return false; 45 46 // Finally the index_map for the result must be invertible. 
For now just 47 // verify it is a permutation. 48 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); 49 return producerResultIndexMap.isPermutation(); 50 } 51 52 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of 53 /// the `producer` to use in the fused operation given the indexing map of the 54 /// result of the producer in the consumer. 55 static void getIndexingMapOfProducerOperandsInFusedOp( 56 LinalgOp producer, AffineMap fusedConsumerArgIndexMap, 57 SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) { 58 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 59 // from consumer loop -> consumer arg tensor index/producer result tensor 60 // index. The fused loop is same as the consumer loop. For each producer arg 61 // the indexing map to be computed is a map from consumer loop -> producer 62 // arg tensor index. 63 64 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); 65 // producerResultIndexMap is a map from producer loop -> tensor index. 66 // Compute the inverse to get map from tensor index -> producer loop. 67 // The inverse is a map from producer result tensor index -> producer loop. 68 AffineMap invProducerResultIndexMap = 69 inversePermutation(producerResultIndexMap); 70 assert(invProducerResultIndexMap && 71 "expected producer result indexig map to be invertible"); 72 for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) { 73 // argMap is a map from producer loop -> producer arg tensor index. 74 AffineMap argMap = producer.getInputIndexingMap(argNum); 75 76 // Compose argMap with invProducerResultIndexMap to get a map from 77 // producer result tensor index -> producer arg tensor index. 78 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 79 80 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 81 // consumer loop/ fused loop -> producer arg tensor index. 
    AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
    fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
  }
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty. `consumerToProducerLoopsMap` maps the loops of the fused
/// (consumer) iteration space onto the producer's loops, and is used to
/// rewrite the producer's index arguments when it is an indexed_generic op.
static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
                                        Operation *fusedOp, LinalgOp producer,
                                        LinalgOp consumer,
                                        AffineMap consumerToProducerLoopsMap,
                                        unsigned consumerIdx, unsigned nloops) {
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  // The new block is owned by the fused op's region once pushed back.
  Block *fusedBlock = new Block();
  fusedOp->getRegion(0).push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // The block arguments are
  // [index_0, index_1, ... ,
  //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
  //   producer_operand_0, ... , producer_operand_(n-1)],
  //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
  // , where n is the number of producer's operand and m is the number
  // consumer's operand.
  // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
  // generic op. In this case, there are no indices in block arguments.
  unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
                                    ? producer.getNumLoops()
                                    : 0;
  unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
                                    ? consumer.getNumLoops()
                                    : 0;
  // The fused op carries index arguments if either side is indexed_generic.
  unsigned numFusedOpIndices =
      (isa<IndexedGenericOp>(producer.getOperation()) ||
       isa<IndexedGenericOp>(consumer.getOperation()))
          ? std::max(producer.getNumLoops(), consumer.getNumLoops())
          : 0;
  // Firstly, add all the indices to the block arguments.
  for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
    fusedBlock->addArgument(rewriter.getIndexType());
  // Map the arguments for the unmodified args from the consumer.
  for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
    if (consumerArg.index() == consumerIdx + numConsumerIndices) {
      // Map the arguments for the args from the producer.
      for (auto producerArg :
           llvm::enumerate(producerBlock.getArguments().take_front(
               producer.getNumInputs() + numProducerIndices))) {
        // If producer is an indexed_generic op, map the indices from consumer
        // loop to producer loop (because the fusedOp is built based on
        // consumer's perspective).
        if (producerArg.index() < numProducerIndices) {
          auto newIndex = rewriter.create<mlir::AffineApplyOp>(
              producer.getLoc(),
              consumerToProducerLoopsMap.getSubMap(producerArg.index()),
              fusedBlock->getArguments().take_front(numFusedOpIndices));
          mapper.map(producerArg.value(), newIndex);
        } else {
          mapper.map(producerArg.value(),
                     fusedBlock->addArgument(producerArg.value().getType()));
        }
      }
      continue;
    }

    // If consumer is an indexed_generic op, map the indices to the block
    // arguments directly. Otherwise, add the same type of argument and map to
    // it.
    if (consumerArg.index() < numConsumerIndices) {
      mapper.map(consumerArg.value(),
                 fusedBlock->getArgument(consumerArg.index()));
    } else {
      mapper.map(consumerArg.value(),
                 fusedBlock->addArgument(consumerArg.value().getType()));
    }
  }

  // Add operations from producer (except the yield operation) to the fused
  // op.
  for (auto &op : producerBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to: the consumer block
      // argument being fused away is replaced by the (cloned) producer value
      // that was yielded.
      Value yieldVal = yieldOp.getOperand(0);
      if (Value clonedVal = mapper.lookupOrNull(yieldVal))
        mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
                   clonedVal);
      continue;
    }
    rewriter.clone(op, mapper);
  }
  // Clone the consumer body (including its yield) after the producer body.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);
}

/// Fuse `producer` into the consumer that owns `consumerOpOperand`. Returns
/// the results of the fused operation, or llvm::None if the preconditions of
/// `areTensorOpsFusable` do not hold.
static Optional<SmallVector<Value, 1>>
fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
                  PatternRewriter &rewriter) {
  LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
  unsigned consumerIdx = consumerOpOperand.getOperandNumber();
  if (!areTensorOpsFusable(producer, consumer, consumerIdx))
    return llvm::None;

  // The fused operand is dropped; all producer inputs are spliced in its
  // place.
  unsigned numFusedOperands =
      producer.getNumInputs() + consumer.getNumInputs() - 1;

  // Compute the fused operands list.
  SmallVector<Value, 2> fusedOperands;
  fusedOperands.reserve(numFusedOperands);
  auto consumerOperands = consumer.getInputs();
  auto producerOperands = producer.getInputs();
  fusedOperands.assign(consumerOperands.begin(),
                       std::next(consumerOperands.begin(), consumerIdx));
  fusedOperands.append(producerOperands.begin(), producerOperands.end());
  fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
                       consumerOperands.end());

  // Compute indexing_maps for the fused operation. The indexing_maps for the
  // operands of the consumers that aren't fused are the same. The
  // indexing_maps for the producers need to be computed based on the
  // indexing_map of the operand at consumerIdx in the consumer.
  SmallVector<Attribute, 4> fusedIndexMaps;
  auto consumerIndexMaps = consumer.indexing_maps();
  fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
  fusedIndexMaps.assign(consumerIndexMaps.begin(),
                        std::next(consumerIndexMaps.begin(), consumerIdx));
  // Compute indexing maps for the producer args in the fused operation.
211 getIndexingMapOfProducerOperandsInFusedOp( 212 producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); 213 214 // Append the indexing maps for the remaining consumer operands. 215 fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), 216 consumerIndexMaps.end()); 217 218 // Generate the fused op. 219 LinalgOp fusedOp; 220 if (isa<GenericOp>(producer.getOperation()) && 221 isa<GenericOp>(consumer.getOperation())) { 222 fusedOp = 223 rewriter 224 .create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(), 225 /*inputs=*/fusedOperands, 226 // TODO: handle outputs. 227 consumer.getOutputs(), 228 rewriter.getArrayAttr(fusedIndexMaps), 229 consumer.iterator_types(), 230 /*doc=*/nullptr, 231 /*library_call=*/nullptr, 232 /*sparse=*/nullptr) 233 .getOperation(); 234 } else { 235 fusedOp = 236 rewriter 237 .create<IndexedGenericOp>( 238 consumer.getLoc(), consumer->getResultTypes(), 239 /*inputs=*/fusedOperands, 240 // TODO: handle outputs. 241 consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps), 242 consumer.iterator_types(), 243 /*doc=*/nullptr, 244 /*library_call=*/nullptr, 245 /*sparse=*/nullptr) 246 .getOperation(); 247 } 248 249 // Construct an AffineMap from consumer loops to producer loops. 
250 // consumer loop -> tensor index 251 AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx); 252 // producer loop -> tensor index 253 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); 254 // tensor index -> producer loop 255 AffineMap invProducerResultIndexMap = 256 inversePermutation(producerResultIndexMap); 257 assert(invProducerResultIndexMap && 258 "expected producer result indexig map to be invertible"); 259 // consumer loop -> producer loop 260 AffineMap consumerToProducerLoopsMap = 261 invProducerResultIndexMap.compose(consumerResultIndexMap); 262 263 generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer, 264 consumer, consumerToProducerLoopsMap, consumerIdx, 265 consumer.getNumLoops()); 266 return SmallVector<Value, 1>(fusedOp->getResults()); 267 } 268 269 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` 270 /// provided, given the shape of the source tensor that corresponds to the 271 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions 272 /// are "row-major" ordered logically. 273 /// 274 /// For example: 275 /// 276 /// %0 = op ... : tensor<?x?x4x5xf32> 277 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` 278 /// 279 /// and reshape: 280 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, 281 /// affine_map<(i, j, k, l) -> (j, k, l)>] : 282 /// tensor<?x?x4x5xf32> into tensor<?x?xf32> 283 /// 284 /// would be rewritten into: 285 /// %0 = op ... 
: tensor<?x?x4x5xf32> 286 /// with output index_map 287 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` 288 static AffineMap linearizeCollapsedDims(AffineMap sourceMap, 289 ArrayRef<int64_t> sourceShape, 290 ArrayRef<AffineMap> reassociationMaps) { 291 SmallVector<AffineExpr, 4> resultExprs; 292 resultExprs.reserve(reassociationMaps.size()); 293 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults(); 294 MLIRContext *context = sourceMap.getContext(); 295 296 // Compute the result exprs based on the reassociation maps. 297 for (AffineMap map : reassociationMaps) { 298 ArrayRef<AffineExpr> collapsedDims = map.getResults(); 299 // Assume that they are in-order and contiguous (already checked in 300 // verifier). 301 assert(!collapsedDims.empty()); 302 unsigned startDim = 303 collapsedDims.front().cast<AffineDimExpr>().getPosition(); 304 SmallVector<int64_t, 4> sizes; 305 SmallVector<AffineExpr, 4> dimExprs; 306 for (auto en : 307 llvm::zip(sourceShape.slice(startDim, collapsedDims.size()), 308 sourceExprs.slice(startDim, collapsedDims.size()))) { 309 if (std::get<0>(en) == 1) 310 continue; 311 sizes.push_back(std::get<0>(en)); 312 dimExprs.push_back(std::get<1>(en)); 313 } 314 AffineExpr linearizedExpr = 315 makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); 316 resultExprs.push_back(linearizedExpr); 317 } 318 return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(), 319 resultExprs, context); 320 } 321 322 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is 323 /// true) or its producer (if `asProducer` is false) given the indexing map at 324 /// its use. 325 static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp, 326 AffineMap useIndexMap, 327 bool asProducer) { 328 RankedTensorType returnType = reshapeOp.getResultType(); 329 RankedTensorType operandType = reshapeOp.getSrcType(); 330 // Reshape is fusable with its consumer (i.e. 
reshape as a producer) when its 331 // operand is of lesser rank than the result. Fusing when operand has higher 332 // rank will require use of mods and divs in the indexing maps of the fused op 333 // which would make it non-invertible. Similarly reshape is fused with its 334 // producer (i.e. reshape as consumer) only if the return type has lesser 335 // rank. 336 if ((asProducer && reshapeOp.getSrcType().hasStaticShape() && 337 returnType.getRank() < operandType.getRank()) || 338 (!asProducer && reshapeOp.getResultType().hasStaticShape() && 339 operandType.getRank() < returnType.getRank())) 340 return false; 341 return useIndexMap.isPermutation(); 342 } 343 344 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op` 345 /// is a linalg.generic operation, the create a `linalg.generic` operation with 346 /// the given `args`. Expects `op` to be `linalg.generic` or 347 /// `linalg.indexed_generic`. 348 template <typename... Args> 349 static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter, 350 Args... args) { 351 if (isa<GenericOp>(op.getOperation())) 352 return rewriter.create<GenericOp>(args...); 353 if (isa<IndexedGenericOp>(op.getOperation())) 354 return rewriter.create<IndexedGenericOp>(args...); 355 llvm_unreachable( 356 "expected only linalg.generic or linalg.indexed_generic ops"); 357 return nullptr; 358 } 359 360 /// Check if the reshape operation is only expansion into/collapsing of 361 /// unit-dimension. 
static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
                                   ArrayRef<AffineMap> reassociation) {
  for (auto &map : reassociation) {
    // Count how many of the expanded dimensions in this group are unit dims.
    unsigned numUnitDims = 0;
    for (AffineExpr expr : map.getResults()) {
      unsigned position = expr.cast<AffineDimExpr>().getPosition();
      if (expandedShape[position] == 1)
        numUnitDims++;
    }
    // Each group must consist of exactly one non-unit dimension plus any
    // number of unit dimensions.
    if (numUnitDims != map.getNumResults() - 1)
      return false;
  }
  return true;
}

/// Conditions for folding a generic/indexed-generic operation with a reshape op
/// by expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
/// Consider
///
/// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
///        indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                         affine_map<(d0, d1, d2) -> (d1, d2)>,
///                         affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
/// %d = linalg.tensor_reshape %c
///        [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
///         affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
///         affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if the
/// generic/indexed-generic op loop dimensionality is increased to match the
/// result (operand) of the tensor_reshape when the reshape is expanding
/// (folding). The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map helps compute the indexing maps of the modified op. For
/// the above example, based on the reassociation map it can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
/// %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since operands to the linalg generic are now 5D, reshapes can be introduced
/// to make it consistent
///
/// %0 = linalg.tensor_reshape %a
///        [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2),
///         affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4),
///         affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
/// %1 = linalg.tensor_reshape %b
///        [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2),
///         affine_map<(e0, e1, e2, e3) -> (e3)]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with its producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               unsigned fusedTensorIndex) {
  // Is fusable only if:
  // - The linalgOp is a generic op, or an indexed_generic.
  // - All the indexing maps for operands and results in linalgOp are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops in linalgOp are parallel loops.
  return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
         linalgOp.hasTensorSemantics() &&
         llvm::all_of(linalgOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
         llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic/indexed_generic operation to fold the
/// reshape with it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic/indexed_generic op, the `reassociationMaps` of the reshape op
  // and the shape of the expanded op.
  LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices, 4> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t, 4>, 4> expandedShapeMap;
  /// Total number of loops (dimensions) in the expanded operation.
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     unsigned fusedTensorIndex,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);

  // Loop extents are needed for the dims not touched by the reshape; bail out
  // if they cannot be determined statically.
  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      getStaticLoopRanges(linalgOp);
  if (!originalLoopRange)
    return linalgOp.emitError("unable to find loop range for operation");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimension in the expanded op that correspond to each
  // dimension of the original op.
  SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
526 unsigned sum = 0; 527 reassociation.reserve(fusedIndexMap.getNumDims()); 528 for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) { 529 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value()); 530 reassociation.emplace_back(seq.begin(), seq.end()); 531 sum += numFoldedDim.value(); 532 } 533 expandedOpNumDims = sum; 534 return success(); 535 } 536 537 /// To expand an indexed_generic operation, the body of the indexed generic op 538 /// need to be modified appropriately. Specifically, uses of arguments for 539 /// induction variables in the original operation need to be replaced with 540 /// linearization of the corresponding arguments in the expanded op. That 541 /// requires the shape of the expanded dimensions (at least all but the most 542 /// significant. For now check that these are all statically sized. Note that 543 /// this could be extended to handle dynamic case, but the implementation below 544 /// uses `affine.apply` which seems to have issues when the shapes are not 545 /// static. 546 LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp, 547 const ExpansionInfo &expansionInfo) { 548 for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) { 549 ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i); 550 if (expandedShape.size() == 1) 551 continue; 552 for (int64_t shape : expandedShape.drop_front()) { 553 if (ShapedType::isDynamic(shape)) { 554 return linalgOp.emitError( 555 "unable to fuse indexed generic op where the expanded dim is " 556 "dynamic"); 557 } 558 } 559 } 560 return success(); 561 } 562 563 /// Return the indexing map to use in the expanded op for a given the 564 /// `indexingMap` of the original operation. 
565 static AffineMap 566 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 567 const ExpansionInfo &expansionInfo) { 568 SmallVector<AffineExpr, 4> newExprs; 569 for (AffineExpr expr : indexingMap.getResults()) { 570 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 571 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>( 572 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { 573 return builder.getAffineDimExpr(static_cast<unsigned>(v)); 574 })); 575 newExprs.append(expandedExprs.begin(), expandedExprs.end()); 576 } 577 return AffineMap::get(expansionInfo.getExpandedOpNumDims(), 578 indexingMap.getNumSymbols(), newExprs, 579 builder.getContext()); 580 } 581 582 /// Return the type of the operand/result to use in the expanded op given the 583 /// type in the original op. 584 static RankedTensorType getExpandedType(RankedTensorType originalType, 585 AffineMap indexingMap, 586 const ExpansionInfo &expansionInfo) { 587 SmallVector<int64_t, 4> expandedShape; 588 for (AffineExpr expr : indexingMap.getResults()) { 589 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 590 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); 591 expandedShape.append(dimExpansion.begin(), dimExpansion.end()); 592 } 593 return RankedTensorType::get(expandedShape, originalType.getElementType()); 594 } 595 596 /// Returns the reassociation maps to use in the `linalg.tensor_reshape` 597 /// operation to convert the operands of the origial operation to operands of 598 /// the expanded operation. The same method is used to compute the 599 /// `linalg.tensor_reshape` used to collapse the result of the expanded op to 600 /// get the value that can replace all uses of the results of the original op. 
601 static SmallVector<ReassociationIndices, 4> 602 getReassociationForExpansion(AffineMap indexingMap, 603 const ExpansionInfo &expansionInfo) { 604 SmallVector<ReassociationIndices, 4> reassociation; 605 unsigned numReshapeDims = 0; 606 for (AffineExpr expr : indexingMap.getResults()) { 607 unsigned dim = expr.cast<AffineDimExpr>().getPosition(); 608 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); 609 auto indices = llvm::to_vector<2>( 610 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims)); 611 reassociation.emplace_back(std::move(indices)); 612 numReshapeDims += numExpandedDims; 613 } 614 return reassociation; 615 } 616 617 /// Build the body of the expanded IndexedGenericOp. The arguments for the 618 /// induction variables of the original operation need to be recovered by 619 /// linearizing the arguments of the corresponding dimensions of the expanded 620 /// op. For now it is assumed that the shapes of the expanded op needed for 621 /// linearization are static. 622 static void buildExpandedIndexedGenericOpRegion( 623 PatternRewriter &rewriter, Location loc, Region &originalOpRegion, 624 Region &fusedOpRegion, const ExpansionInfo &expansionInfo) { 625 assert(fusedOpRegion.empty() && "expected fused op to have empty region"); 626 // Create an entry block in the fused region with same number of arguments 627 // as the fused op 628 Block *fusedEntryBlock = new Block; 629 fusedOpRegion.push_back(fusedEntryBlock); 630 rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion, 631 fusedOpRegion.end()); 632 633 // Merge the entry block of the fused op with the cloned blocks. For this 634 // compute the value for arguments of the region in the original operation 635 // in terms of the arguments of the fused op. Since the original operation 636 // is expanded, the expanded dimensions need to be folded back to get the 637 // replacement value for the arguments corresponding to interation index. 
638 // For now this expects that all the loop ranges are constants, which is 639 // true if the shapes are all static. This has already been checked in the 640 // precondition. 641 using namespace edsc::op; 642 using namespace edsc::intrinsics; 643 OpBuilder::InsertionGuard guard(rewriter); 644 SmallVector<Value, 4> argReplacements(originalOpRegion.getNumArguments()); 645 rewriter.setInsertionPointToStart(fusedEntryBlock); 646 edsc::ScopedContext scopedContext(rewriter, loc); 647 IndexType indexType = rewriter.getIndexType(); 648 for (auto i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) { 649 Value linearizedIndex = fusedEntryBlock->addArgument(indexType); 650 ArrayRef<int64_t> expandedDimsShape = 651 expansionInfo.getExpandedShapeOfDim(i).drop_front(); 652 for (unsigned shape : expandedDimsShape) { 653 assert(!ShapedType::isDynamic(shape)); 654 linearizedIndex = linearizedIndex * std_constant_index(shape); 655 linearizedIndex = 656 linearizedIndex + fusedEntryBlock->addArgument(indexType); 657 } 658 argReplacements[i] = linearizedIndex; 659 } 660 for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(), 661 argReplacements.size())) { 662 argReplacements[i] = 663 fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType()); 664 } 665 rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock, 666 argReplacements); 667 } 668 669 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic 670 /// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those 671 /// conditions have been satisfied. 672 static Optional<SmallVector<Value, 1>> 673 fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp, 674 unsigned fusedTensorIndex, 675 PatternRewriter &rewriter) { 676 assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) && 677 "preconditions for fuse operation failed"); 678 // Check if reshape is expanding or collapsing. 
  bool isExpanding =
      reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
  RankedTensorType expandedType =
      isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
                                   reshapeOp.getReassociationMaps(),
                                   expandedType.getShape())))
    return llvm::None;

  // Indexed generic ops additionally need static expanded shapes to rebuild
  // their induction-variable arguments.
  if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
      failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value, 4> expandedOpOperands;
  for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
    // The fused operand is replaced directly by the reshape's source.
    if (operand.index() == fusedTensorIndex) {
      expandedOpOperands.push_back(reshapeOp.src());
      continue;
    }
    AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
    RankedTensorType expandedOperandType =
        getExpandedType(operand.value().getType().cast<RankedTensorType>(),
                        indexingMap, expansionInfo);
    if (expandedOperandType != operand.value().getType()) {
      // Reshape the operand to get the right type.
      SmallVector<ReassociationIndices, 4> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
          linalgOp.getLoc(), expandedOperandType, operand.value(),
          reassociation));
      continue;
    }
    expandedOpOperands.push_back(operand.value());
  }

  Location loc = linalgOp.getLoc();
  SmallVector<Value, 1> outputs;
  // NOTE(review): outputs (and hence resultTypes) are only populated when the
  // expanded type differs from the original output type; presumably the
  // all-projected-permutation precondition guarantees that holds for every
  // output — confirm, otherwise the result count below would be short.
  for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
    AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
    RankedTensorType expandedOutputType =
        getExpandedType(result.value().getType().cast<RankedTensorType>(),
                        indexingMap, expansionInfo);
    if (expandedOutputType != result.value().getType()) {
      SmallVector<ReassociationIndices, 4> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      outputs.push_back(rewriter.create<TensorReshapeOp>(
          linalgOp.getLoc(), expandedOutputType, result.value(),
          reassociation));
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                          getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  LinalgOp fusedOp = createLinalgOpOfSameType(
      linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
      /*inputs=*/expandedOpOperands, outputs, expandedOpIndexingMaps,
      iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);

  // Generic ops can reuse the original body as-is; indexed generic ops need
  // their index arguments rebuilt from the expanded index arguments.
  if (isa<GenericOp>(linalgOp.getOperation())) {
    rewriter.cloneRegionBefore(originalRegion, fusedRegion,
                               fusedRegion.begin());
  } else {
    assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
    buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
                                        fusedRegion, expansionInfo);
  }

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value, 1> resultVals;
  for (auto result : llvm::enumerate(linalgOp->getResults())) {
    if (!isExpanding &&
        resultTypes[result.index()] != result.value().getType()) {
      SmallVector<ReassociationIndices, 4> reassociation =
          getReassociationForExpansion(
              linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
      resultVals.push_back(rewriter.create<TensorReshapeOp>(
          linalgOp.getLoc(), result.value().getType(),
          fusedOp->getResult(result.index()), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(result.index()));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fold tensor_reshape op with its consumer by using the source of
/// the reshape op as the operand in the consumer (instead of the result of the
/// tensor_reshape op) when the tensor_reshape op is collapsing. The
/// corresponding index map in the consumer needs to be modified to linearize
/// the folded dimension.
785 /// 786 /// For example, 787 /// 788 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 789 /// %0 = linalg.tensor_reshape %arg0 790 /// [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>, 791 /// affine_map<(i, j, k, l) -> (l)>] 792 /// tensor<?x?x?xf32> into tensor<?x?x4x?xf32> 793 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... } 794 /// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ... 795 /// -> tensor<?x?x4x?xf32> 796 /// 797 /// can be folded into 798 /// 799 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> 800 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 801 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... } 802 /// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ... 803 /// -> tensor<?x?x4x?xf32> 804 template <typename LinalgOpTy, bool foldUnitDimReshapesOnly> 805 struct FoldProducerReshapeOpByLinearization 806 : public OpRewritePattern<LinalgOpTy> { 807 using OpRewritePattern<LinalgOpTy>::OpRewritePattern; 808 809 LogicalResult matchAndRewrite(LinalgOpTy op, 810 PatternRewriter &rewriter) const override { 811 if (!op.hasTensorSemantics()) 812 return failure(); 813 LinalgOp linalgOp = cast<LinalgOp>(op.getOperation()); 814 for (auto operand : llvm::enumerate(linalgOp.getInputs())) { 815 TensorReshapeOp reshapeOp = 816 operand.value().getDefiningOp<TensorReshapeOp>(); 817 if (!reshapeOp || 818 !isTensorReshapeOpFoldableByLinearization( 819 reshapeOp, linalgOp.getInputIndexingMap(operand.index()), 820 /*asProducer =*/true) || 821 (foldUnitDimReshapesOnly && 822 !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(), 823 reshapeOp.getReassociationMaps()))) 824 continue; 825 826 // Compute the fused operands list, 827 SmallVector<Value, 2> fusedOperands(linalgOp.getInputs()); 828 fusedOperands[operand.index()] = reshapeOp.src(); 829 fusedOperands.append(linalgOp.getOutputs().begin(), 830 linalgOp.getOutputs().end()); 831 832 // 
Compute indexing_maps for the fused operation. The indexing_maps for 833 // the operands of the consumers that arent fused are the same. 834 SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>( 835 op.indexing_maps().template getAsValueRange<AffineMapAttr>()); 836 837 // Accepted consumer maps are either identity or permutation. 838 auto invMap = inversePermutation(fusedIndexMaps[operand.index()]); 839 840 // Compute the indexing map to use for the result of the producer. 841 AffineMap modifiedMap = 842 linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(), 843 reshapeOp.getReassociationMaps()); 844 for (AffineExpr expr : modifiedMap.getResults()) { 845 if (!expr.isPureAffine()) 846 return failure(); 847 } 848 fusedIndexMaps[operand.index()] = modifiedMap; 849 850 // Further check that the resulting index maps can be fused and 851 // inverted. Without this the resultant op is not legal. 852 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) 853 return op.emitRemark("fused op loop bound computation failed"); 854 855 rewriter.startRootUpdate(op); 856 op->setOperands(fusedOperands); 857 op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps)); 858 rewriter.finalizeRootUpdate(op); 859 if (reshapeOp.use_empty()) 860 rewriter.eraseOp(reshapeOp); 861 return success(); 862 } 863 return failure(); 864 } 865 }; 866 867 /// Pattern to fuse a tensor_reshape op with its consumer 868 /// generic/indexed_generic op, when the reshape op is collapsing 869 /// dimensions. The dimensionality of the loop in the consumer is expanded. 
870 template <typename GenericOpTy> 871 struct FoldWithProducerReshapeOpByExpansion 872 : public OpRewritePattern<GenericOpTy> { 873 using OpRewritePattern<GenericOpTy>::OpRewritePattern; 874 875 LogicalResult matchAndRewrite(GenericOpTy genericOp, 876 PatternRewriter &rewriter) const override { 877 LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation()); 878 for (auto operand : llvm::enumerate(linalgOp.getInputs())) { 879 TensorReshapeOp reshapeOp = 880 operand.value().getDefiningOp<TensorReshapeOp>(); 881 if (!reshapeOp) 882 continue; 883 884 // Fold only if 885 // - The tensor reshape op is folding. 886 // - All constraints of fusing with reshape by expansion are met. 887 if (reshapeOp.getSrcType().getRank() < 888 reshapeOp.getResultType().getRank() || 889 !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) || 890 isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), 891 reshapeOp.getReassociationMaps())) 892 continue; 893 894 Optional<SmallVector<Value, 1>> replacementValues = 895 fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(), 896 rewriter); 897 if (!replacementValues) 898 return failure(); 899 rewriter.replaceOp(genericOp, replacementValues.getValue()); 900 if (reshapeOp.use_empty()) 901 rewriter.eraseOp(reshapeOp); 902 return success(); 903 } 904 return failure(); 905 } 906 }; 907 908 /// Pattern to fold tensor_reshape op with its producer. The corresponding index 909 /// map in the consumer needs to be modified to linearize the folded dimension. 
910 template <bool foldUnitDimReshapesOnly> 911 struct FoldConsumerReshapeOpByLinearization 912 : public OpRewritePattern<TensorReshapeOp> { 913 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 914 915 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 916 PatternRewriter &rewriter) const override { 917 LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>(); 918 if (!producer || 919 !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) || 920 !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 || 921 !isTensorReshapeOpFoldableByLinearization( 922 reshapeOp, producer.getOutputIndexingMap(0), 923 /*asProducer =*/false) || 924 (foldUnitDimReshapesOnly && 925 !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), 926 reshapeOp.getReassociationMaps()))) 927 return failure(); 928 // The indexing_maps for the operands of the fused operation are same as 929 // those for the operands of the producer. 930 SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>( 931 producer.indexing_maps().getAsValueRange<AffineMapAttr>()); 932 933 auto invMap = inversePermutation(producer.getOutputIndexingMap(0)); 934 935 // Compute the indexing map to use for the operand of the producer. 936 AffineMap modifiedMap = 937 linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(), 938 reshapeOp.getReassociationMaps()); 939 for (AffineExpr expr : modifiedMap.getResults()) { 940 if (!expr.isPureAffine()) 941 return producer.emitRemark("fused op indexing map is not affine"); 942 } 943 fusedIndexMaps.back() = modifiedMap; 944 945 // Further check that the resulting index maps can be fused and 946 // inverted. Without this the resultant op is not legal. 
947 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) 948 return reshapeOp.emitRemark("fused op loop bound computation failed"); 949 950 Location loc = producer.getLoc(); 951 Value output = rewriter.create<TensorReshapeOp>( 952 loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs()); 953 LinalgOp fusedOp = createLinalgOpOfSameType( 954 producer, rewriter, loc, reshapeOp.getResultType(), 955 /*inputs=*/producer.getInputs(), 956 // TODO: handle outputs. 957 /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps), 958 producer.iterator_types(), 959 /*doc=*/nullptr, 960 /*library_call=*/nullptr, 961 /*sparse=*/nullptr); 962 auto &fusedRegion = fusedOp->getRegion(0); 963 rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion, 964 fusedRegion.begin()); 965 rewriter.replaceOp(reshapeOp, fusedOp->getResults()); 966 if (producer.use_empty()) 967 rewriter.eraseOp(producer); 968 return success(); 969 } 970 }; 971 972 /// Pattern to fold a tensor_reshape op with its producer generic op if the 973 /// tensor_reshape op is expanding, by expanding the dimensionality of the loop 974 /// in the producer op. 975 struct FoldReshapeWithGenericOpByExpansion 976 : public OpRewritePattern<TensorReshapeOp> { 977 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 978 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 979 PatternRewriter &rewriter) const override { 980 // Fold only if 981 // - The tensor reshape op is a expanding case. 982 // - All constraints of fusing with reshape by expansion are met. 
983 if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank()) 984 return failure(); 985 LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>(); 986 if (!producer || producer.getNumOutputs() != 1 || 987 !isFusableWithReshapeByDimExpansion(producer, 988 producer.getNumInputs()) || 989 isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(), 990 reshapeOp.getReassociationMaps())) 991 return failure(); 992 Optional<SmallVector<Value, 1>> replacementValues = 993 fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(), 994 rewriter); 995 if (!replacementValues) 996 return failure(); 997 rewriter.replaceOp(reshapeOp, replacementValues.getValue()); 998 if (producer.use_empty()) 999 rewriter.eraseOp(producer); 1000 return success(); 1001 } 1002 }; 1003 1004 /// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant. 1005 template <typename LinalgOpTy> 1006 struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> { 1007 using OpRewritePattern<LinalgOpTy>::OpRewritePattern; 1008 1009 LogicalResult matchAndRewrite(LinalgOpTy op, 1010 PatternRewriter &rewriter) const override { 1011 if (!op.hasTensorSemantics()) 1012 return failure(); 1013 LinalgOp linalgOp = cast<LinalgOp>(op.getOperation()); 1014 for (auto operand : llvm::enumerate(linalgOp.getInputs())) { 1015 ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>(); 1016 if (!constantOp || 1017 !constantOp.value().cast<DenseElementsAttr>().isSplat()) 1018 continue; 1019 1020 // The indexing_maps for the operands of the fused operation are same as 1021 // those for the operands of the linalgOp without the indexing map at 1022 // operand.index() 1023 SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>( 1024 linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>()); 1025 fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index())); 1026 1027 // The operands list is same as the linalgOp with the argument for 1028 // constant index 
dropped. 1029 SmallVector<Value, 4> fusedOperands(linalgOp.getInputs()); 1030 fusedOperands.erase(std::next(fusedOperands.begin(), operand.index())); 1031 1032 // Create a constant scalar value from the splat constant. 1033 Value scalarConstant = rewriter.create<ConstantOp>( 1034 constantOp.getLoc(), 1035 constantOp.value().cast<DenseElementsAttr>().getSplatValue()); 1036 1037 LinalgOp fusedOp = createLinalgOpOfSameType( 1038 linalgOp, rewriter, rewriter.getUnknownLoc(), 1039 linalgOp->getResultTypes(), 1040 /*inputs=*/fusedOperands, 1041 /*outputs=*/linalgOp.getOutputs(), 1042 rewriter.getAffineMapArrayAttr(fusedIndexMaps), 1043 linalgOp.iterator_types(), 1044 /*doc=*/nullptr, 1045 /*library_call=*/nullptr, 1046 /*sparse=*/nullptr); 1047 1048 // Map the block argument corresponding to the replaced argument with the 1049 // scalar constant. 1050 Region &linalgOpRegion = linalgOp->getRegion(0); 1051 Block &entryBlock = *linalgOpRegion.begin(); 1052 unsigned argIndex = entryBlock.getNumArguments() - 1053 linalgOp.getNumShapedOperands() + operand.index(); 1054 BlockAndValueMapping mapping; 1055 mapping.map(entryBlock.getArgument(argIndex), scalarConstant); 1056 Region &fusedRegion = fusedOp->getRegion(0); 1057 rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion, 1058 fusedRegion.begin(), mapping); 1059 rewriter.replaceOp(linalgOp, fusedOp->getResults()); 1060 if (constantOp.use_empty()) 1061 rewriter.eraseOp(constantOp); 1062 return success(); 1063 } 1064 return failure(); 1065 } 1066 }; 1067 } // namespace 1068 1069 Optional<SmallVector<Value, 1>> 1070 mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, 1071 OpOperand &consumerOpOperand) { 1072 Operation *producer = consumerOpOperand.get().getDefiningOp(); 1073 if (!producer || producer->getNumResults() != 1) 1074 return llvm::None; 1075 1076 // Fuse when consumer is GenericOp or IndexedGenericOp. 
1077 if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) || 1078 !isa<GenericOp, IndexedGenericOp>(producer)) 1079 return llvm::None; 1080 1081 return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand, 1082 rewriter); 1083 } 1084 1085 namespace { 1086 /// Patterns to fuse a generic op, with the producer of its operands. 1087 template <typename LinalgOpTy> 1088 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> { 1089 using OpRewritePattern<LinalgOpTy>::OpRewritePattern; 1090 1091 LogicalResult matchAndRewrite(LinalgOpTy op, 1092 PatternRewriter &rewriter) const override { 1093 // Find the first operand that is defined by another generic op on tensors. 1094 for (OpOperand &opOperand : op.getShapedOpOperands()) { 1095 Operation *producer = opOperand.get().getDefiningOp(); 1096 if (!producer) 1097 continue; 1098 Optional<SmallVector<Value, 1>> fusedOpResults = 1099 fuseTensorOps(rewriter, opOperand); 1100 if (fusedOpResults) { 1101 rewriter.replaceOp(op, *fusedOpResults); 1102 if (producer->use_empty()) 1103 rewriter.eraseOp(producer); 1104 return success(); 1105 } 1106 } 1107 return failure(); 1108 } 1109 }; 1110 1111 /// Pass that fuses generic ops on tensors. Used only for testing. 1112 struct FusionOfTensorOpsPass 1113 : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> { 1114 void runOnOperation() override { 1115 OwningRewritePatternList patterns; 1116 Operation *op = getOperation(); 1117 populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns); 1118 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); 1119 } 1120 }; 1121 1122 /// Pass to test folding of reshape op with generic/indexed_generic ops by 1123 /// linearization. 
1124 struct FoldReshapeOpsByLinearizationPass 1125 : public LinalgFoldReshapeOpsByLinearizationBase< 1126 FoldReshapeOpsByLinearizationPass> { 1127 void runOnOperation() override { 1128 OwningRewritePatternList patterns; 1129 Operation *op = getOperation(); 1130 populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns); 1131 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); 1132 } 1133 }; 1134 1135 } // namespace 1136 1137 void mlir::populateFoldReshapeOpsByLinearizationPatterns( 1138 MLIRContext *context, OwningRewritePatternList &patterns) { 1139 patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>, 1140 FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>, 1141 FoldConsumerReshapeOpByLinearization<false>>(context); 1142 } 1143 1144 void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( 1145 MLIRContext *context, OwningRewritePatternList &patterns) { 1146 patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>, 1147 FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>, 1148 FoldConsumerReshapeOpByLinearization<true>>(context); 1149 } 1150 1151 void mlir::populateFoldReshapeOpsByExpansionPatterns( 1152 MLIRContext *context, OwningRewritePatternList &patterns) { 1153 patterns.insert<FoldReshapeWithGenericOpByExpansion, 1154 FoldWithProducerReshapeOpByExpansion<GenericOp>, 1155 FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>( 1156 context); 1157 } 1158 1159 void mlir::populateLinalgTensorOpsFusionPatterns( 1160 MLIRContext *context, OwningRewritePatternList &patterns) { 1161 patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>, 1162 FoldSplatConstants<GenericOp>, 1163 FoldSplatConstants<IndexedGenericOp>>(context); 1164 populateFoldReshapeOpsByExpansionPatterns(context, patterns); 1165 GenericOp::getCanonicalizationPatterns(patterns, context); 1166 IndexedGenericOp::getCanonicalizationPatterns(patterns, context); 1167 
TensorReshapeOp::getCanonicalizationPatterns(patterns, context); 1168 } 1169 1170 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() { 1171 return std::make_unique<FusionOfTensorOpsPass>(); 1172 } 1173 1174 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() { 1175 return std::make_unique<FoldReshapeOpsByLinearizationPass>(); 1176 } 1177