1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include <iterator> 19 #include <memory> 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 static Value sourceMaterializationCallback(OpBuilder &builder, Type type, 25 ValueRange inputs, Location loc) { 26 assert(inputs.size() == 1); 27 // A detensored value is converted back by creating a new tensor from its 28 // element(s). 29 auto createNewTensorOp = builder.create<tensor::FromElementsOp>( 30 loc, inputs[0].getType(), inputs[0]); 31 32 // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to 33 // a tensor<dtype> instead. 34 return builder.create<linalg::TensorCollapseShapeOp>( 35 loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{}); 36 } 37 38 namespace { 39 /// Defines the criteria a TensorType must follow in order to be considered 40 /// "detensorable". 41 /// 42 /// NOTE: For now, only 0-D tensors are supported. 43 /// 44 /// Returns true if tensorType can be detensored. 45 bool canBeDetensored(TensorType tensorType) { 46 return tensorType.hasRank() && tensorType.getRank() == 0; 47 } 48 49 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { 50 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); 51 return genericOp && 52 llvm::all_of( 53 genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) { 54 return !typeConverter.isLegal(opOperand->get().getType()); 55 }); 56 } 57 58 /// A conversion patttern for detensoring `linalg.generic` ops. 59 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 60 public: 61 using OpConversionPattern::OpConversionPattern; 62 LogicalResult 63 matchAndRewrite(GenericOp op, OpAdaptor adaptor, 64 ConversionPatternRewriter &rewriter) const override { 65 Block *originalBlock = op->getBlock(); 66 67 // Gather some information about the op before inling its region. 68 Block *opEntryBlock = &*op.region().begin(); 69 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); 70 71 // Split the op's region before the op. This way, we have a clear insertion 72 // point in which the op can be inlined. 73 Block *newBlock = originalBlock->splitBlock(op); 74 rewriter.inlineRegionBefore(op.region(), newBlock); 75 // Now that op's region is inlined, the operands of its YieldOp are mapped 76 // to the materialized target values. Therefore, we can replace the op's 77 // uses with those of its YielOp's operands. 78 rewriter.replaceOp(op, yieldOp->getOperands()); 79 80 // No need for these intermediate blocks, merge them into 1. 81 rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); 82 rewriter.mergeBlocks(newBlock, originalBlock, {}); 83 84 rewriter.eraseOp(&*Block::iterator(yieldOp)); 85 86 return success(); 87 } 88 }; 89 90 /// A conversion pattern for detensoring internal (non-entry) blocks within a 91 /// function. 92 struct FunctionNonEntryBlockConversion : public ConversionPattern { 93 FunctionNonEntryBlockConversion(StringRef functionLikeOpName, 94 MLIRContext *ctx, TypeConverter &converter, 95 DenseSet<BlockArgument> blockArgsToDetensor) 96 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx), 97 blockArgsToDetensor(blockArgsToDetensor) {} 98 99 LogicalResult 100 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 101 ConversionPatternRewriter &rewriter) const override { 102 rewriter.startRootUpdate(op); 103 Region ®ion = function_like_impl::getFunctionBody(op); 104 SmallVector<TypeConverter::SignatureConversion, 2> conversions; 105 106 for (Block &block : llvm::drop_begin(region, 1)) { 107 conversions.emplace_back(block.getNumArguments()); 108 TypeConverter::SignatureConversion &back = conversions.back(); 109 110 for (BlockArgument blockArgument : block.getArguments()) { 111 int idx = blockArgument.getArgNumber(); 112 113 if (blockArgsToDetensor.count(blockArgument)) 114 back.addInputs(idx, {getTypeConverter()->convertType( 115 block.getArgumentTypes()[idx])}); 116 else 117 back.addInputs(idx, {block.getArgumentTypes()[idx]}); 118 } 119 } 120 121 if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, 122 conversions))) { 123 rewriter.cancelRootUpdate(op); 124 return failure(); 125 } 126 127 rewriter.finalizeRootUpdate(op); 128 return success(); 129 } 130 131 private: 132 const DenseSet<BlockArgument> blockArgsToDetensor; 133 }; 134 135 class DetensorizeTypeConverter : public TypeConverter { 136 public: 137 DetensorizeTypeConverter() { 138 addConversion([](Type type) { return type; }); 139 140 // A TensorType that can be detensored, is converted to the underlying 141 // element type. 142 addConversion([](TensorType tensorType) -> Type { 143 if (canBeDetensored(tensorType)) 144 return tensorType.getElementType(); 145 146 return tensorType; 147 }); 148 149 // A tensor value is detensoried by extracting its element(s). 150 addTargetMaterialization([](OpBuilder &builder, Type type, 151 ValueRange inputs, Location loc) -> Value { 152 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 153 }); 154 155 addSourceMaterialization(sourceMaterializationCallback); 156 addArgumentMaterialization(sourceMaterializationCallback); 157 } 158 }; 159 160 /// Canonicalizes the pattern of the form 161 /// 162 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 163 /// %reshaped_tensor = linalg.tensor_collapse_shape %tensor [] 164 /// : tensor<1xi32> into tensor<i32> 165 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32> 166 /// 167 /// to just %element. 168 struct ExtractFromReshapeFromElements 169 : public OpRewritePattern<tensor::ExtractOp> { 170 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 171 172 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 173 PatternRewriter &rewriter) const final { 174 if (!extract.indices().empty()) 175 return failure(); 176 177 auto tensorReshape = 178 extract.tensor().getDefiningOp<TensorCollapseShapeOp>(); 179 if (tensorReshape == nullptr) 180 return failure(); 181 182 auto tensorFromElements = 183 tensorReshape.getOperand() 184 .getDefiningOp<mlir::tensor::FromElementsOp>(); 185 if (tensorFromElements == nullptr) 186 return failure(); 187 188 rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); 189 return success(); 190 } 191 }; 192 193 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 194 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 195 LinalgDetensorize() = default; 196 LinalgDetensorize(const LinalgDetensorize &pass) 197 : LinalgDetensorizeBase<LinalgDetensorize>() {} 198 199 class CostModel { 200 public: 201 virtual ~CostModel() = default; 202 203 /// A cost model algorithm computes the following outputs: 204 /// 205 /// - opsToDetensor: the list of linalg ops that should be 206 /// detensored. 207 /// 208 /// - blockArgsToDetensor: since the operands and results of detensored 209 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 210 /// from a BB argument and a linalg op's output can be passed to successor 211 /// BBs), we need to maintain the sub-set of arguments that should be 212 /// detensored (i.e. converted by typeConverter) for each affected BB. 213 /// 214 /// Example: 215 /// 216 /// For the following snippet: 217 /// ... 218 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 219 /// %7 = linalg.init_tensor [] : tensor<i32> 220 /// %8 = linalg.generic #attrs 221 /// ins(%6, %6 : tensor<i32>, tensor<i32>) 222 /// outs(%7 : tensor<i32>) { 223 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 224 /// %9 = arith.addi %arg0, %arg1 : i32 225 /// linalg.yield %9 : i32 226 /// } -> tensor<i32> 227 /// %10 = "some.op"(%9) 228 /// br ^bb2(%8 : tensor<i32>) 229 /// ... 230 /// 231 /// if the cost model decides that the linalg.generic op should be 232 /// detensored, then: 233 /// - opsToDetensor should be = {linalg.generic{add}}. 234 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 235 virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter, 236 DenseSet<Operation *> &opsToDetensor, 237 DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 238 239 /// From the blockArgsToDetensor set computed by a CostModel 240 /// implementation, this method computes the corresponding branch op 241 /// detensoring. The result is a map from a branch op to a subset of indices 242 /// of its operands. The indices specify which of the branch op's operands 243 /// should be detensored. 244 /// 245 /// For the previous example, this method would compute: {bb2 -> {0}}. 246 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 247 const DenseSet<BlockArgument> &blockArgsToDetensor) { 248 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 249 250 for (auto blockArgumentElem : blockArgsToDetensor) { 251 Block *block = blockArgumentElem.getOwner(); 252 253 for (PredecessorIterator pred = block->pred_begin(); 254 pred != block->pred_end(); ++pred) { 255 BranchOpInterface terminator = 256 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 257 auto blockOperands = 258 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 259 260 if (!blockOperands || blockOperands->empty()) 261 continue; 262 263 detensorableBranchOps[terminator].insert( 264 blockOperands->getBeginOperandIndex() + 265 blockArgumentElem.getArgNumber()); 266 } 267 } 268 269 return detensorableBranchOps; 270 } 271 }; 272 273 /// Detensorize linalg ops involved in control-flow within a function. 274 /// 275 /// This model starts from BranchOps and CondBranchOps within a function. For 276 /// each such branch, the model then walks the use-def chain for the branch's 277 /// condition backwards in order to understand where the condition's value 278 /// comes from. If the condition value is (indirectly) computed by a linalg op 279 /// that can be detensored, the model then continues walking the use-def chain 280 /// in order to understand where the linalg op's operands come from. This 281 /// leads to discovering a "detensoring component". A detensoring component is 282 /// the set of operations + block arguments that are involved in control-flow 283 /// AND can be detensored. 284 class ControlFlowDetectionModel : public CostModel { 285 public: 286 void compute(FuncOp func, DetensorizeTypeConverter typeConverter, 287 DenseSet<Operation *> &opsToDetensor, 288 DenseSet<BlockArgument> &blockArgsToDetensor) override { 289 SmallVector<Value> workList; 290 291 func.walk([&](CondBranchOp condBr) { 292 for (auto operand : condBr.getOperands()) { 293 workList.push_back(operand); 294 } 295 }); 296 297 func.walk([&](BranchOp br) { 298 for (auto operand : br.getOperands()) { 299 workList.push_back(operand); 300 } 301 }); 302 303 DenseSet<Value> visitedValues; 304 DenseSet<Operation *> visitedOps; 305 306 // For a (to-be-detesored) value, check if it "escapes" the block by being 307 // passed to terminator. If it does, then workList is updated with the 308 // corresponding argument to the successor block. 309 auto updateWorkListWithSuccessorArguments = 310 [&](Value value, BranchOpInterface terminator) { 311 if (!terminator) 312 return; 313 314 for (auto operandIdx : 315 llvm::seq<unsigned>(0, terminator->getOperands().size())) { 316 Value operand = terminator->getOperand(operandIdx); 317 318 if (operand == value) { 319 auto succBlockArg = 320 terminator.getSuccessorBlockArgument(operandIdx); 321 322 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) 323 workList.push_back(*succBlockArg); 324 } 325 } 326 }; 327 328 while (!workList.empty()) { 329 Value currentItem = workList.pop_back_val(); 330 331 if (!visitedValues.insert(currentItem).second) 332 continue; 333 334 // 1 - Look forward: 335 // 1.1 - If currentItem escapes to one or more successors, add 336 // the corresponding successor arguments to workList. 337 updateWorkListWithSuccessorArguments( 338 currentItem, dyn_cast<BranchOpInterface>( 339 currentItem.getParentBlock()->getTerminator())); 340 341 // 1.2 - For each user of currentItem, add the defined values to 342 // workList. This way, the user ops can be inspected later if they are 343 // detensorable and if so, their operands will be added to workList to 344 // potentially discover other parts of the detensorable component. 345 for (auto *user : currentItem.getUsers()) 346 for (Value result : user->getResults()) 347 workList.push_back(result); 348 349 // 2 - Look backward: 350 // 2.1 - The current item is defined by a block argument. If the owner 351 // block is a non-entry one, then: 352 // * Add the argument to blockArgsToDetensor. 353 // * Walk the use-def chain backwards to add each predecessor's 354 // terminator-operands corresponding to currentItem to workList. 355 if (currentItem.dyn_cast<BlockArgument>()) { 356 BlockArgument currentItemBlockArgument = 357 currentItem.cast<BlockArgument>(); 358 Block *ownerBlock = currentItemBlockArgument.getOwner(); 359 360 // Function arguments are not detensored/converted. 361 if (&*ownerBlock->getParent()->begin() == ownerBlock) 362 continue; 363 364 // This inner-block argument is involved in control-flow, it should be 365 // detensored. 366 blockArgsToDetensor.insert(currentItemBlockArgument); 367 368 for (PredecessorIterator pred = ownerBlock->pred_begin(); 369 pred != ownerBlock->pred_end(); ++pred) { 370 BranchOpInterface predTerminator = 371 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 372 373 // TODO: For now, we give up if any of the control-flow components 374 // in a function is not detensorable. Fix that. 375 if (!predTerminator) { 376 opsToDetensor.clear(); 377 blockArgsToDetensor.clear(); 378 return; 379 } 380 381 auto ownerBlockOperands = 382 predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); 383 384 if (!ownerBlockOperands || ownerBlockOperands->empty()) 385 continue; 386 387 // For each predecessor, add the value it passes to that argument to 388 // workList to find out how it's computed. 389 workList.push_back( 390 ownerBlockOperands 391 .getValue()[currentItemBlockArgument.getArgNumber()]); 392 } 393 394 continue; 395 } 396 397 Operation *currentItemDefiningOp = currentItem.getDefiningOp(); 398 399 if (!visitedOps.insert(currentItemDefiningOp).second) 400 continue; 401 402 // 2.2 - The current item is computed by a GenericOp. If the op should 403 // be detensored, then: 404 // * Add it to opsToDetensor. 405 // * Add its operands to workList to discover other parts of the 406 // potentially detensorable component. 407 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { 408 // The op was encountered already, no need to inspect it again. 409 if (opsToDetensor.count(genericOp)) 410 continue; 411 412 // The op should not be detensored, give up on it but continue with 413 // discovering the rest of the control-flow component. 414 if (!shouldBeDetensored(genericOp, typeConverter)) { 415 continue; 416 } 417 418 opsToDetensor.insert(genericOp); 419 420 for (Value genericOpOperand : genericOp.inputs()) 421 workList.push_back(genericOpOperand); 422 423 continue; 424 } 425 426 // 2.3 - The current item is the result of a FromElementsOp, it will be 427 // trivially detensored later as part of canonicalization patterns 428 // applied at the end of detensoring. 429 // 430 // Note: No need to check whether the result type of this op is 431 // detensorable since if it wasn't we wouldn't reach that point in the 432 // work list. 433 if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp)) 434 continue; 435 436 // 2.4 - The current item is the result of a scalar op, add all its 437 // operands to the work list. 438 if (llvm::all_of( 439 currentItemDefiningOp->getResultTypes(), 440 [&](Type resultType) { return resultType.isIntOrFloat(); })) 441 for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) 442 workList.push_back(scalarOpOperand); 443 } 444 445 // Since the cost model gives up on some ops (see the details of step 2.2 446 // above), block arguments that correspond to the values produced by those 447 // ops should not be detensored as well. 448 449 DenseSet<BlockArgument> blockArgsToRemove; 450 451 for (auto &blockArg : blockArgsToDetensor) { 452 Block *block = blockArg.getParentBlock(); 453 454 // For the potentially detensorable block argument, find the 455 // correpsonding operands in predecessor blocks. 456 for (PredecessorIterator pred = block->pred_begin(); 457 pred != block->pred_end(); ++pred) { 458 BranchOpInterface terminator = 459 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 460 auto blockOperands = 461 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 462 463 if (!blockOperands || blockOperands->empty()) 464 continue; 465 466 Operation *definingOp = 467 terminator 468 ->getOperand(blockOperands->getBeginOperandIndex() + 469 blockArg.getArgNumber()) 470 .getDefiningOp(); 471 472 // If the operand is defined by a GenericOp that will not be 473 // detensored, then do not detensor the corresponding block argument. 474 if (dyn_cast_or_null<GenericOp>(definingOp) && 475 opsToDetensor.count(definingOp) == 0) { 476 blockArgsToRemove.insert(blockArg); 477 break; 478 } 479 } 480 } 481 482 for (auto &blockArg : blockArgsToRemove) { 483 blockArgsToDetensor.erase(blockArg); 484 } 485 } 486 }; 487 488 /// Detensorize everything that can detensored. 489 class AggressiveDetensoringModel : public CostModel { 490 public: 491 void compute(FuncOp func, DetensorizeTypeConverter typeConverter, 492 DenseSet<Operation *> &opsToDetensor, 493 DenseSet<BlockArgument> &blockArgsToDetensor) override { 494 func.walk([&](GenericOp genericOp) { 495 if (shouldBeDetensored(genericOp, typeConverter)) 496 opsToDetensor.insert(genericOp); 497 }); 498 499 for (Block &block : llvm::drop_begin(func.getBody(), 1)) 500 for (BlockArgument blockArgument : block.getArguments()) 501 blockArgsToDetensor.insert(blockArgument); 502 } 503 }; 504 505 void runOnFunction() override { 506 MLIRContext *context = &getContext(); 507 DetensorizeTypeConverter typeConverter; 508 RewritePatternSet patterns(context); 509 ConversionTarget target(*context); 510 DenseSet<Operation *> opsToDetensor; 511 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 512 DenseSet<BlockArgument> blockArgsToDetensor; 513 514 if (aggressiveMode.getValue()) { 515 AggressiveDetensoringModel costModel; 516 costModel.compute(getFunction(), typeConverter, opsToDetensor, 517 blockArgsToDetensor); 518 519 } else { 520 ControlFlowDetectionModel costModel; 521 costModel.compute(getFunction(), typeConverter, opsToDetensor, 522 blockArgsToDetensor); 523 } 524 525 detensorableBranchOps = 526 CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 527 528 target.addDynamicallyLegalOp<GenericOp>( 529 [&](GenericOp op) { return !opsToDetensor.count(op); }); 530 531 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 532 // A function is legal if all of its non-entry blocks are legal. We 533 // don't legalize the entry block (i.e. the function's signature) 534 // since detensoring can't happen along external calling convention 535 // boundaries, which we conservatively approximate as all function 536 // signatures. 537 return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) { 538 if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) { 539 return blockArgument.getOwner() == &block && 540 !typeConverter.isLegal(blockArgument.getType()); 541 })) { 542 return false; 543 } 544 return true; 545 }); 546 }); 547 548 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 549 if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 550 isLegalForReturnOpTypeConversionPattern(op, typeConverter, 551 /*returnOpAlwaysLegal*/ true)) 552 return true; 553 554 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 555 if (!detensorableBranchOps.count(branchOp)) 556 return true; 557 558 for (auto operandIdx : detensorableBranchOps[branchOp]) 559 if (!typeConverter.isLegal( 560 branchOp->getOperand(operandIdx).getType())) 561 return false; 562 563 return true; 564 } 565 566 return false; 567 }); 568 569 patterns.insert<DetensorizeGenericOp>(typeConverter, context); 570 patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(), 571 context, typeConverter, 572 blockArgsToDetensor); 573 // Since non-entry block arguments get detensorized, we also need to 574 // update the control flow inside the function to reflect the correct 575 // types. 576 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 577 int operandIdx) -> bool { 578 return detensorableBranchOps.count(branchOp) && 579 detensorableBranchOps[branchOp].count(operandIdx); 580 }; 581 582 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 583 shouldConvertBranchOperand); 584 585 if (failed(applyFullConversion(getFunction(), target, std::move(patterns)))) 586 signalPassFailure(); 587 588 RewritePatternSet canonPatterns(context); 589 canonPatterns.add<ExtractFromReshapeFromElements>(context); 590 if (failed(applyPatternsAndFoldGreedily(getFunction(), 591 std::move(canonPatterns)))) 592 signalPassFailure(); 593 } 594 595 Option<bool> aggressiveMode{ 596 *this, "aggressive-mode", 597 llvm::cl::desc("Detensorize all ops that qualify for detensoring along " 598 "with branch operands and basic-block arguments.")}; 599 }; 600 } // namespace 601 602 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 603 return std::make_unique<LinalgDetensorize>(); 604 } 605