1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Linalg/IR/Linalg.h" 11 #include "mlir/Dialect/Linalg/Passes.h" 12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/OpDefinition.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 #include <iterator> 18 #include <memory> 19 20 using namespace mlir; 21 using namespace mlir::linalg; 22 23 static Value sourceMaterializationCallback(OpBuilder &builder, Type type, 24 ValueRange inputs, Location loc) { 25 assert(inputs.size() == 1); 26 if (inputs[0].getType().isa<TensorType>()) 27 return nullptr; 28 29 // A detensored value is converted back by creating a new tensor from its 30 // element(s). 31 auto createNewTensorOp = 32 builder.create<tensor::FromElementsOp>(loc, inputs[0]); 33 34 // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to 35 // a tensor<dtype> instead. 36 return builder.create<tensor::CollapseShapeOp>( 37 loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{}); 38 } 39 40 namespace { 41 /// Defines the criteria a TensorType must follow in order to be considered 42 /// "detensorable". 43 /// 44 /// NOTE: For now, only 0-D tensors are supported. 45 /// 46 /// Returns true if tensorType can be detensored. 47 bool canBeDetensored(TensorType tensorType) { 48 return tensorType.hasRank() && tensorType.getRank() == 0; 49 } 50 51 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { 52 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); 53 return genericOp && 54 llvm::all_of( 55 genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) { 56 return !typeConverter.isLegal(opOperand->get().getType()); 57 }); 58 } 59 60 /// A conversion patttern for detensoring `linalg.generic` ops. 61 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 62 public: 63 using OpConversionPattern::OpConversionPattern; 64 LogicalResult 65 matchAndRewrite(GenericOp op, OpAdaptor adaptor, 66 ConversionPatternRewriter &rewriter) const override { 67 Block *originalBlock = op->getBlock(); 68 69 // Gather some information about the op before inling its region. 70 Block *opEntryBlock = &*op.region().begin(); 71 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); 72 73 // Split the op's region before the op. This way, we have a clear insertion 74 // point in which the op can be inlined. 75 Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op)); 76 rewriter.inlineRegionBefore(op.region(), newBlock); 77 // Now that op's region is inlined, the operands of its YieldOp are mapped 78 // to the materialized target values. Therefore, we can replace the op's 79 // uses with those of its YielOp's operands. 80 rewriter.replaceOp(op, yieldOp->getOperands()); 81 82 // No need for these intermediate blocks, merge them into 1. 83 rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); 84 rewriter.mergeBlocks(newBlock, originalBlock, {}); 85 86 rewriter.eraseOp(&*Block::iterator(yieldOp)); 87 88 return success(); 89 } 90 }; 91 92 /// A conversion pattern for detensoring internal (non-entry) blocks within a 93 /// function. 94 struct FunctionNonEntryBlockConversion : public ConversionPattern { 95 FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter, 96 DenseSet<BlockArgument> blockArgsToDetensor) 97 : ConversionPattern(converter, MatchTraitOpTypeTag(), 98 TypeID::get<OpTrait::FunctionLike>(), /*benefit=*/1, 99 ctx), 100 blockArgsToDetensor(blockArgsToDetensor) {} 101 102 LogicalResult 103 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 104 ConversionPatternRewriter &rewriter) const override { 105 rewriter.startRootUpdate(op); 106 Region ®ion = function_like_impl::getFunctionBody(op); 107 SmallVector<TypeConverter::SignatureConversion, 2> conversions; 108 109 for (Block &block : llvm::drop_begin(region, 1)) { 110 conversions.emplace_back(block.getNumArguments()); 111 TypeConverter::SignatureConversion &back = conversions.back(); 112 113 for (BlockArgument blockArgument : block.getArguments()) { 114 int idx = blockArgument.getArgNumber(); 115 116 if (blockArgsToDetensor.count(blockArgument)) 117 back.addInputs(idx, {getTypeConverter()->convertType( 118 block.getArgumentTypes()[idx])}); 119 else 120 back.addInputs(idx, {block.getArgumentTypes()[idx]}); 121 } 122 } 123 124 if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, 125 conversions))) { 126 rewriter.cancelRootUpdate(op); 127 return failure(); 128 } 129 130 rewriter.finalizeRootUpdate(op); 131 return success(); 132 } 133 134 private: 135 const DenseSet<BlockArgument> blockArgsToDetensor; 136 }; 137 138 class DetensorizeTypeConverter : public TypeConverter { 139 public: 140 DetensorizeTypeConverter() { 141 addConversion([](Type type) { return type; }); 142 143 // A TensorType that can be detensored, is converted to the underlying 144 // element type. 145 addConversion([](TensorType tensorType) -> Type { 146 if (canBeDetensored(tensorType)) 147 return tensorType.getElementType(); 148 149 return tensorType; 150 }); 151 152 // A tensor value is detensoried by extracting its element(s). 153 addTargetMaterialization([](OpBuilder &builder, Type type, 154 ValueRange inputs, Location loc) -> Value { 155 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 156 }); 157 158 addSourceMaterialization(sourceMaterializationCallback); 159 addArgumentMaterialization(sourceMaterializationCallback); 160 } 161 }; 162 163 /// Canonicalizes the pattern of the form 164 /// 165 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 166 /// %reshaped_tensor = tensor.collapse_shape %tensor [] 167 /// : tensor<1xi32> into tensor<i32> 168 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32> 169 /// 170 /// to just %element. 171 struct ExtractFromReshapeFromElements 172 : public OpRewritePattern<tensor::ExtractOp> { 173 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 174 175 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 176 PatternRewriter &rewriter) const final { 177 if (!extract.indices().empty()) 178 return failure(); 179 180 auto tensorReshape = 181 extract.tensor().getDefiningOp<tensor::CollapseShapeOp>(); 182 if (tensorReshape == nullptr) 183 return failure(); 184 185 auto tensorFromElements = 186 tensorReshape.getOperand() 187 .getDefiningOp<mlir::tensor::FromElementsOp>(); 188 if (tensorFromElements == nullptr) 189 return failure(); 190 191 rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); 192 return success(); 193 } 194 }; 195 196 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 197 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 198 LinalgDetensorize() = default; 199 LinalgDetensorize(const LinalgDetensorize &pass) 200 : LinalgDetensorizeBase<LinalgDetensorize>() {} 201 202 class CostModel { 203 public: 204 virtual ~CostModel() = default; 205 206 /// A cost model algorithm computes the following outputs: 207 /// 208 /// - opsToDetensor: the list of linalg ops that should be 209 /// detensored. 210 /// 211 /// - blockArgsToDetensor: since the operands and results of detensored 212 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 213 /// from a BB argument and a linalg op's output can be passed to successor 214 /// BBs), we need to maintain the sub-set of arguments that should be 215 /// detensored (i.e. converted by typeConverter) for each affected BB. 216 /// 217 /// Example: 218 /// 219 /// For the following snippet: 220 /// ... 221 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 222 /// %7 = linalg.init_tensor [] : tensor<i32> 223 /// %8 = linalg.generic #attrs 224 /// ins(%6, %6 : tensor<i32>, tensor<i32>) 225 /// outs(%7 : tensor<i32>) { 226 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 227 /// %9 = arith.addi %arg0, %arg1 : i32 228 /// linalg.yield %9 : i32 229 /// } -> tensor<i32> 230 /// %10 = "some.op"(%9) 231 /// br ^bb2(%8 : tensor<i32>) 232 /// ... 233 /// 234 /// if the cost model decides that the linalg.generic op should be 235 /// detensored, then: 236 /// - opsToDetensor should be = {linalg.generic{add}}. 237 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 238 virtual void compute(Operation *func, 239 DetensorizeTypeConverter typeConverter, 240 DenseSet<Operation *> &opsToDetensor, 241 DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 242 243 /// From the blockArgsToDetensor set computed by a CostModel 244 /// implementation, this method computes the corresponding branch op 245 /// detensoring. The result is a map from a branch op to a subset of indices 246 /// of its operands. The indices specify which of the branch op's operands 247 /// should be detensored. 248 /// 249 /// For the previous example, this method would compute: {bb2 -> {0}}. 250 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 251 const DenseSet<BlockArgument> &blockArgsToDetensor) { 252 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 253 254 for (auto blockArgumentElem : blockArgsToDetensor) { 255 Block *block = blockArgumentElem.getOwner(); 256 257 for (PredecessorIterator pred = block->pred_begin(); 258 pred != block->pred_end(); ++pred) { 259 BranchOpInterface terminator = 260 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 261 auto blockOperands = 262 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 263 264 if (!blockOperands || blockOperands->empty()) 265 continue; 266 267 detensorableBranchOps[terminator].insert( 268 blockOperands->getBeginOperandIndex() + 269 blockArgumentElem.getArgNumber()); 270 } 271 } 272 273 return detensorableBranchOps; 274 } 275 }; 276 277 /// Detensorize linalg ops involved in control-flow within a function. 278 /// 279 /// This model starts from BranchOps and CondBranchOps within a function. For 280 /// each such branch, the model then walks the use-def chain for the branch's 281 /// condition backwards in order to understand where the condition's value 282 /// comes from. If the condition value is (indirectly) computed by a linalg op 283 /// that can be detensored, the model then continues walking the use-def chain 284 /// in order to understand where the linalg op's operands come from. This 285 /// leads to discovering a "detensoring component". A detensoring component is 286 /// the set of operations + block arguments that are involved in control-flow 287 /// AND can be detensored. 288 class ControlFlowDetectionModel : public CostModel { 289 public: 290 void compute(Operation *func, DetensorizeTypeConverter typeConverter, 291 DenseSet<Operation *> &opsToDetensor, 292 DenseSet<BlockArgument> &blockArgsToDetensor) override { 293 SmallVector<Value> workList; 294 295 func->walk([&](CondBranchOp condBr) { 296 for (auto operand : condBr.getOperands()) { 297 workList.push_back(operand); 298 } 299 }); 300 301 func->walk([&](BranchOp br) { 302 for (auto operand : br.getOperands()) { 303 workList.push_back(operand); 304 } 305 }); 306 307 DenseSet<Value> visitedValues; 308 DenseSet<Operation *> visitedOps; 309 310 // For a (to-be-detesored) value, check if it "escapes" the block by being 311 // passed to terminator. If it does, then workList is updated with the 312 // corresponding argument to the successor block. 313 auto updateWorkListWithSuccessorArguments = 314 [&](Value value, BranchOpInterface terminator) { 315 if (!terminator) 316 return; 317 318 for (auto operandIdx : 319 llvm::seq<unsigned>(0, terminator->getOperands().size())) { 320 Value operand = terminator->getOperand(operandIdx); 321 322 if (operand == value) { 323 auto succBlockArg = 324 terminator.getSuccessorBlockArgument(operandIdx); 325 326 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) 327 workList.push_back(*succBlockArg); 328 } 329 } 330 }; 331 332 while (!workList.empty()) { 333 Value currentItem = workList.pop_back_val(); 334 335 if (!visitedValues.insert(currentItem).second) 336 continue; 337 338 // 1 - Look forward: 339 // 1.1 - If currentItem escapes to one or more successors, add 340 // the corresponding successor arguments to workList. 341 updateWorkListWithSuccessorArguments( 342 currentItem, dyn_cast<BranchOpInterface>( 343 currentItem.getParentBlock()->getTerminator())); 344 345 // 1.2 - For each user of currentItem, add the defined values to 346 // workList. This way, the user ops can be inspected later if they are 347 // detensorable and if so, their operands will be added to workList to 348 // potentially discover other parts of the detensorable component. 349 for (auto *user : currentItem.getUsers()) 350 for (Value result : user->getResults()) 351 workList.push_back(result); 352 353 // 2 - Look backward: 354 // 2.1 - The current item is defined by a block argument. If the owner 355 // block is a non-entry one, then: 356 // * Add the argument to blockArgsToDetensor. 357 // * Walk the use-def chain backwards to add each predecessor's 358 // terminator-operands corresponding to currentItem to workList. 359 if (currentItem.dyn_cast<BlockArgument>()) { 360 BlockArgument currentItemBlockArgument = 361 currentItem.cast<BlockArgument>(); 362 Block *ownerBlock = currentItemBlockArgument.getOwner(); 363 364 // Function arguments are not detensored/converted. 365 if (&*ownerBlock->getParent()->begin() == ownerBlock) 366 continue; 367 368 // This inner-block argument is involved in control-flow, it should be 369 // detensored. 370 blockArgsToDetensor.insert(currentItemBlockArgument); 371 372 for (PredecessorIterator pred = ownerBlock->pred_begin(); 373 pred != ownerBlock->pred_end(); ++pred) { 374 BranchOpInterface predTerminator = 375 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 376 377 // TODO: For now, we give up if any of the control-flow components 378 // in a function is not detensorable. Fix that. 379 if (!predTerminator) { 380 opsToDetensor.clear(); 381 blockArgsToDetensor.clear(); 382 return; 383 } 384 385 auto ownerBlockOperands = 386 predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); 387 388 if (!ownerBlockOperands || ownerBlockOperands->empty()) 389 continue; 390 391 // For each predecessor, add the value it passes to that argument to 392 // workList to find out how it's computed. 393 workList.push_back( 394 ownerBlockOperands 395 .getValue()[currentItemBlockArgument.getArgNumber()]); 396 } 397 398 continue; 399 } 400 401 Operation *currentItemDefiningOp = currentItem.getDefiningOp(); 402 403 if (!visitedOps.insert(currentItemDefiningOp).second) 404 continue; 405 406 // 2.2 - The current item is computed by a GenericOp. If the op should 407 // be detensored, then: 408 // * Add it to opsToDetensor. 409 // * Add its operands to workList to discover other parts of the 410 // potentially detensorable component. 411 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { 412 // The op was encountered already, no need to inspect it again. 413 if (opsToDetensor.count(genericOp)) 414 continue; 415 416 // The op should not be detensored, give up on it but continue with 417 // discovering the rest of the control-flow component. 418 if (!shouldBeDetensored(genericOp, typeConverter)) { 419 continue; 420 } 421 422 opsToDetensor.insert(genericOp); 423 424 for (Value genericOpOperand : genericOp.inputs()) 425 workList.push_back(genericOpOperand); 426 427 continue; 428 } 429 430 // 2.3 - The current item is the result of a FromElementsOp, it will be 431 // trivially detensored later as part of canonicalization patterns 432 // applied at the end of detensoring. 433 // 434 // Note: No need to check whether the result type of this op is 435 // detensorable since if it wasn't we wouldn't reach that point in the 436 // work list. 437 if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp)) 438 continue; 439 440 // 2.4 - The current item is the result of a scalar op, add all its 441 // operands to the work list. 442 if (llvm::all_of( 443 currentItemDefiningOp->getResultTypes(), 444 [&](Type resultType) { return resultType.isIntOrFloat(); })) 445 for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) 446 workList.push_back(scalarOpOperand); 447 } 448 449 // Since the cost model gives up on some ops (see the details of step 2.2 450 // above), block arguments that correspond to the values produced by those 451 // ops should not be detensored as well. 452 453 DenseSet<BlockArgument> blockArgsToRemove; 454 455 for (auto &blockArg : blockArgsToDetensor) { 456 Block *block = blockArg.getParentBlock(); 457 458 // For the potentially detensorable block argument, find the 459 // correpsonding operands in predecessor blocks. 460 for (PredecessorIterator pred = block->pred_begin(); 461 pred != block->pred_end(); ++pred) { 462 BranchOpInterface terminator = 463 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 464 auto blockOperands = 465 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 466 467 if (!blockOperands || blockOperands->empty()) 468 continue; 469 470 Operation *definingOp = 471 terminator 472 ->getOperand(blockOperands->getBeginOperandIndex() + 473 blockArg.getArgNumber()) 474 .getDefiningOp(); 475 476 // If the operand is defined by a GenericOp that will not be 477 // detensored, then do not detensor the corresponding block argument. 478 if (dyn_cast_or_null<GenericOp>(definingOp) && 479 opsToDetensor.count(definingOp) == 0) { 480 blockArgsToRemove.insert(blockArg); 481 break; 482 } 483 } 484 } 485 486 for (auto &blockArg : blockArgsToRemove) { 487 blockArgsToDetensor.erase(blockArg); 488 } 489 } 490 }; 491 492 /// Detensorize everything that can detensored. 493 class AggressiveDetensoringModel : public CostModel { 494 public: 495 void compute(Operation *func, DetensorizeTypeConverter typeConverter, 496 DenseSet<Operation *> &opsToDetensor, 497 DenseSet<BlockArgument> &blockArgsToDetensor) override { 498 func->walk([&](GenericOp genericOp) { 499 if (shouldBeDetensored(genericOp, typeConverter)) 500 opsToDetensor.insert(genericOp); 501 }); 502 503 for (Block &block : 504 llvm::drop_begin(function_like_impl::getFunctionBody(func), 1)) 505 for (BlockArgument blockArgument : block.getArguments()) 506 blockArgsToDetensor.insert(blockArgument); 507 } 508 }; 509 510 void runOnOperation() override { 511 assert(getOperation()->hasTrait<OpTrait::FunctionLike>() && 512 "DetensorizePass can only be run on FunctionLike operations"); 513 MLIRContext *context = &getContext(); 514 DetensorizeTypeConverter typeConverter; 515 RewritePatternSet patterns(context); 516 ConversionTarget target(*context); 517 DenseSet<Operation *> opsToDetensor; 518 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 519 DenseSet<BlockArgument> blockArgsToDetensor; 520 521 if (aggressiveMode.getValue()) { 522 AggressiveDetensoringModel costModel; 523 costModel.compute(getOperation(), typeConverter, opsToDetensor, 524 blockArgsToDetensor); 525 526 } else { 527 ControlFlowDetectionModel costModel; 528 costModel.compute(getOperation(), typeConverter, opsToDetensor, 529 blockArgsToDetensor); 530 } 531 532 detensorableBranchOps = 533 CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 534 535 target.addDynamicallyLegalOp<GenericOp>( 536 [&](GenericOp op) { return !opsToDetensor.count(op); }); 537 538 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 539 // A function is legal if all of its non-entry blocks are legal. We 540 // don't legalize the entry block (i.e. the function's signature) 541 // since detensoring can't happen along external calling convention 542 // boundaries, which we conservatively approximate as all function 543 // signatures. 544 if (op->hasTrait<OpTrait::FunctionLike>()) { 545 auto &body = function_like_impl::getFunctionBody(op); 546 return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) { 547 if (llvm::any_of( 548 blockArgsToDetensor, [&](BlockArgument blockArgument) { 549 return blockArgument.getOwner() == &block && 550 !typeConverter.isLegal(blockArgument.getType()); 551 })) { 552 return false; 553 } 554 return true; 555 }); 556 } 557 558 if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 559 isLegalForReturnOpTypeConversionPattern(op, typeConverter, 560 /*returnOpAlwaysLegal*/ true)) 561 return true; 562 563 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 564 if (!detensorableBranchOps.count(branchOp)) 565 return true; 566 567 for (auto operandIdx : detensorableBranchOps[branchOp]) 568 if (!typeConverter.isLegal( 569 branchOp->getOperand(operandIdx).getType())) 570 return false; 571 572 return true; 573 } 574 575 return false; 576 }); 577 578 patterns.insert<DetensorizeGenericOp>(typeConverter, context); 579 patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter, 580 blockArgsToDetensor); 581 // Since non-entry block arguments get detensorized, we also need to 582 // update the control flow inside the function to reflect the correct 583 // types. 584 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 585 int operandIdx) -> bool { 586 return detensorableBranchOps.count(branchOp) && 587 detensorableBranchOps[branchOp].count(operandIdx); 588 }; 589 590 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 591 shouldConvertBranchOperand); 592 593 if (failed( 594 applyFullConversion(getOperation(), target, std::move(patterns)))) 595 signalPassFailure(); 596 597 RewritePatternSet canonPatterns(context); 598 canonPatterns.add<ExtractFromReshapeFromElements>(context); 599 if (failed(applyPatternsAndFoldGreedily(getOperation(), 600 std::move(canonPatterns)))) 601 signalPassFailure(); 602 } 603 604 Option<bool> aggressiveMode{ 605 *this, "aggressive-mode", 606 llvm::cl::desc("Detensorize all ops that qualify for detensoring along " 607 "with branch operands and basic-block arguments.")}; 608 }; 609 } // namespace 610 611 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 612 return std::make_unique<LinalgDetensorize>(); 613 } 614