//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>

using namespace mlir;
using namespace mlir::linalg;

/// Source/argument materialization callback: converts a detensored (scalar)
/// value back to a tensor when a use still requires the tensor form.
///
/// Declines (returns nullptr) if the single input is already a tensor.
/// Otherwise the scalar is wrapped into a tensor<1xdtype> via
/// tensor.from_elements and collapsed to the requested 0-D tensor type.
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  // Already a tensor: nothing to materialize; returning nullptr lets the
  // conversion framework try other materializations.
  if (inputs[0].getType().isa<TensorType>())
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
      loc, inputs[0].getType(), inputs[0]);

  // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
  // a tensor<dtype> instead.
  return builder.create<tensor::CollapseShapeOp>(
      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
}

namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}

/// Returns true if `op` is a linalg.generic whose input and output operands
/// all have types that the given type converter considers illegal (i.e.
/// detensorable tensor types).
bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
  return genericOp &&
         llvm::all_of(
             genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
               return !typeConverter.isLegal(opOperand->get().getType());
             });
}

/// A conversion pattern for detensoring `linalg.generic` ops. The generic
/// op's body is inlined into the enclosing block, replacing the op with the
/// scalar computation it wraps.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the op's region before the op. This way, we have a clear
    // insertion point in which the op can be inlined.
    Block *newBlock = originalBlock->splitBlock(op);
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that op's region is inlined, the operands of its YieldOp are mapped
    // to the materialized target values. Therefore, we can replace the op's
    // uses with those of its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks, merge them into 1.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    // After the merges the yield terminator is a plain op in the middle of
    // the block; erase it.
    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};

/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
95 struct FunctionNonEntryBlockConversion : public ConversionPattern { 96 FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter, 97 DenseSet<BlockArgument> blockArgsToDetensor) 98 : ConversionPattern(converter, MatchTraitOpTypeTag(), 99 TypeID::get<OpTrait::FunctionLike>(), /*benefit=*/1, 100 ctx), 101 blockArgsToDetensor(blockArgsToDetensor) {} 102 103 LogicalResult 104 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 105 ConversionPatternRewriter &rewriter) const override { 106 rewriter.startRootUpdate(op); 107 Region ®ion = function_like_impl::getFunctionBody(op); 108 SmallVector<TypeConverter::SignatureConversion, 2> conversions; 109 110 for (Block &block : llvm::drop_begin(region, 1)) { 111 conversions.emplace_back(block.getNumArguments()); 112 TypeConverter::SignatureConversion &back = conversions.back(); 113 114 for (BlockArgument blockArgument : block.getArguments()) { 115 int idx = blockArgument.getArgNumber(); 116 117 if (blockArgsToDetensor.count(blockArgument)) 118 back.addInputs(idx, {getTypeConverter()->convertType( 119 block.getArgumentTypes()[idx])}); 120 else 121 back.addInputs(idx, {block.getArgumentTypes()[idx]}); 122 } 123 } 124 125 if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, 126 conversions))) { 127 rewriter.cancelRootUpdate(op); 128 return failure(); 129 } 130 131 rewriter.finalizeRootUpdate(op); 132 return success(); 133 } 134 135 private: 136 const DenseSet<BlockArgument> blockArgsToDetensor; 137 }; 138 139 class DetensorizeTypeConverter : public TypeConverter { 140 public: 141 DetensorizeTypeConverter() { 142 addConversion([](Type type) { return type; }); 143 144 // A TensorType that can be detensored, is converted to the underlying 145 // element type. 146 addConversion([](TensorType tensorType) -> Type { 147 if (canBeDetensored(tensorType)) 148 return tensorType.getElementType(); 149 150 return tensorType; 151 }); 152 153 // A tensor value is detensoried by extracting its element(s). 
154 addTargetMaterialization([](OpBuilder &builder, Type type, 155 ValueRange inputs, Location loc) -> Value { 156 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 157 }); 158 159 addSourceMaterialization(sourceMaterializationCallback); 160 addArgumentMaterialization(sourceMaterializationCallback); 161 } 162 }; 163 164 /// Canonicalizes the pattern of the form 165 /// 166 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 167 /// %reshaped_tensor = tensor.collapse_shape %tensor [] 168 /// : tensor<1xi32> into tensor<i32> 169 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32> 170 /// 171 /// to just %element. 172 struct ExtractFromReshapeFromElements 173 : public OpRewritePattern<tensor::ExtractOp> { 174 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 175 176 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 177 PatternRewriter &rewriter) const final { 178 if (!extract.indices().empty()) 179 return failure(); 180 181 auto tensorReshape = 182 extract.tensor().getDefiningOp<tensor::CollapseShapeOp>(); 183 if (tensorReshape == nullptr) 184 return failure(); 185 186 auto tensorFromElements = 187 tensorReshape.getOperand() 188 .getDefiningOp<mlir::tensor::FromElementsOp>(); 189 if (tensorFromElements == nullptr) 190 return failure(); 191 192 rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); 193 return success(); 194 } 195 }; 196 197 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 198 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 199 LinalgDetensorize() = default; 200 LinalgDetensorize(const LinalgDetensorize &pass) 201 : LinalgDetensorizeBase<LinalgDetensorize>() {} 202 203 class CostModel { 204 public: 205 virtual ~CostModel() = default; 206 207 /// A cost model algorithm computes the following outputs: 208 /// 209 /// - opsToDetensor: the list of linalg ops that should be 210 /// detensored. 
211 /// 212 /// - blockArgsToDetensor: since the operands and results of detensored 213 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 214 /// from a BB argument and a linalg op's output can be passed to successor 215 /// BBs), we need to maintain the sub-set of arguments that should be 216 /// detensored (i.e. converted by typeConverter) for each affected BB. 217 /// 218 /// Example: 219 /// 220 /// For the following snippet: 221 /// ... 222 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 223 /// %7 = linalg.init_tensor [] : tensor<i32> 224 /// %8 = linalg.generic #attrs 225 /// ins(%6, %6 : tensor<i32>, tensor<i32>) 226 /// outs(%7 : tensor<i32>) { 227 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 228 /// %9 = arith.addi %arg0, %arg1 : i32 229 /// linalg.yield %9 : i32 230 /// } -> tensor<i32> 231 /// %10 = "some.op"(%9) 232 /// br ^bb2(%8 : tensor<i32>) 233 /// ... 234 /// 235 /// if the cost model decides that the linalg.generic op should be 236 /// detensored, then: 237 /// - opsToDetensor should be = {linalg.generic{add}}. 238 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 239 virtual void compute(Operation *func, 240 DetensorizeTypeConverter typeConverter, 241 DenseSet<Operation *> &opsToDetensor, 242 DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 243 244 /// From the blockArgsToDetensor set computed by a CostModel 245 /// implementation, this method computes the corresponding branch op 246 /// detensoring. The result is a map from a branch op to a subset of indices 247 /// of its operands. The indices specify which of the branch op's operands 248 /// should be detensored. 249 /// 250 /// For the previous example, this method would compute: {bb2 -> {0}}. 
251 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 252 const DenseSet<BlockArgument> &blockArgsToDetensor) { 253 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 254 255 for (auto blockArgumentElem : blockArgsToDetensor) { 256 Block *block = blockArgumentElem.getOwner(); 257 258 for (PredecessorIterator pred = block->pred_begin(); 259 pred != block->pred_end(); ++pred) { 260 BranchOpInterface terminator = 261 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 262 auto blockOperands = 263 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 264 265 if (!blockOperands || blockOperands->empty()) 266 continue; 267 268 detensorableBranchOps[terminator].insert( 269 blockOperands->getBeginOperandIndex() + 270 blockArgumentElem.getArgNumber()); 271 } 272 } 273 274 return detensorableBranchOps; 275 } 276 }; 277 278 /// Detensorize linalg ops involved in control-flow within a function. 279 /// 280 /// This model starts from BranchOps and CondBranchOps within a function. For 281 /// each such branch, the model then walks the use-def chain for the branch's 282 /// condition backwards in order to understand where the condition's value 283 /// comes from. If the condition value is (indirectly) computed by a linalg op 284 /// that can be detensored, the model then continues walking the use-def chain 285 /// in order to understand where the linalg op's operands come from. This 286 /// leads to discovering a "detensoring component". A detensoring component is 287 /// the set of operations + block arguments that are involved in control-flow 288 /// AND can be detensored. 
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      // Seed the work list with every operand of every branch in the
      // function; these are the values that cross block boundaries.
      func->walk([&](CondBranchOp condBr) {
        for (auto operand : condBr.getOperands()) {
          workList.push_back(operand);
        }
      });

      func->walk([&](BranchOp br) {
        for (auto operand : br.getOperands()) {
          workList.push_back(operand);
        }
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to terminator. If it does, then workList is updated
      // with the corresponding argument to the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      // Fixed-point discovery of the detensoring component: each popped
      // value is inspected both forwards (uses) and backwards (definition).
      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1 - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they
        // are detensorable and if so, their operands will be added to
        // workList to potentially discover other parts of the detensorable
        // component.
        for (auto *user : currentItem.getUsers())
          for (Value result : user->getResults())
            workList.push_back(result);

        // 2 - Look backward:
        // 2.1 - The current item is defined by a block argument. If the
        // owner block is a non-entry one, then:
        //   * Add the argument to blockArgsToDetensor.
        //   * Walk the use-def chain backwards to add each predecessor's
        //     terminator-operands corresponding to currentItem to workList.
        if (currentItem.dyn_cast<BlockArgument>()) {
          BlockArgument currentItemBlockArgument =
              currentItem.cast<BlockArgument>();
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, it should
          // be detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow
            // components in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (!ownerBlockOperands || ownerBlockOperands->empty())
              continue;

            // For each predecessor, add the value it passes to that argument
            // to workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands
                    .getValue()[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        // Not a block argument, hence defined by an op (cannot be null
        // here since the block-argument case was handled above).
        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op
        // should be detensored, then:
        //   * Add it to opsToDetensor.
        //   * Add its operands to workList to discover other parts of the
        //     potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already, no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // The op should not be detensored, give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);

          for (Value genericOpOperand : genericOp.inputs())
            workList.push_back(genericOpOperand);

          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp, it will
        // be trivially detensored later as part of canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach that point in
        // the work list.
        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op, add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
            workList.push_back(scalarOpOperand);
      }

      // Since the cost model gives up on some ops (see the details of step
      // 2.2 above), block arguments that correspond to the values produced
      // by those ops should not be detensored as well.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          Operation *definingOp =
              terminator
                  ->getOperand(blockOperands->getBeginOperandIndex() +
                               blockArg.getArgNumber())
                  .getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block
          // argument.
          if (dyn_cast_or_null<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };

  /// Detensorize everything that can be detensored.
494 class AggressiveDetensoringModel : public CostModel { 495 public: 496 void compute(Operation *func, DetensorizeTypeConverter typeConverter, 497 DenseSet<Operation *> &opsToDetensor, 498 DenseSet<BlockArgument> &blockArgsToDetensor) override { 499 func->walk([&](GenericOp genericOp) { 500 if (shouldBeDetensored(genericOp, typeConverter)) 501 opsToDetensor.insert(genericOp); 502 }); 503 504 for (Block &block : 505 llvm::drop_begin(function_like_impl::getFunctionBody(func), 1)) 506 for (BlockArgument blockArgument : block.getArguments()) 507 blockArgsToDetensor.insert(blockArgument); 508 } 509 }; 510 511 void runOnOperation() override { 512 assert(getOperation()->hasTrait<OpTrait::FunctionLike>() && 513 "DetensorizePass can only be run on FunctionLike operations"); 514 MLIRContext *context = &getContext(); 515 DetensorizeTypeConverter typeConverter; 516 RewritePatternSet patterns(context); 517 ConversionTarget target(*context); 518 DenseSet<Operation *> opsToDetensor; 519 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 520 DenseSet<BlockArgument> blockArgsToDetensor; 521 522 if (aggressiveMode.getValue()) { 523 AggressiveDetensoringModel costModel; 524 costModel.compute(getOperation(), typeConverter, opsToDetensor, 525 blockArgsToDetensor); 526 527 } else { 528 ControlFlowDetectionModel costModel; 529 costModel.compute(getOperation(), typeConverter, opsToDetensor, 530 blockArgsToDetensor); 531 } 532 533 detensorableBranchOps = 534 CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 535 536 target.addDynamicallyLegalOp<GenericOp>( 537 [&](GenericOp op) { return !opsToDetensor.count(op); }); 538 539 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 540 // A function is legal if all of its non-entry blocks are legal. We 541 // don't legalize the entry block (i.e. 
the function's signature) 542 // since detensoring can't happen along external calling convention 543 // boundaries, which we conservatively approximate as all function 544 // signatures. 545 if (op->hasTrait<OpTrait::FunctionLike>()) { 546 auto &body = function_like_impl::getFunctionBody(op); 547 return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) { 548 if (llvm::any_of( 549 blockArgsToDetensor, [&](BlockArgument blockArgument) { 550 return blockArgument.getOwner() == &block && 551 !typeConverter.isLegal(blockArgument.getType()); 552 })) { 553 return false; 554 } 555 return true; 556 }); 557 } 558 559 if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 560 isLegalForReturnOpTypeConversionPattern(op, typeConverter, 561 /*returnOpAlwaysLegal*/ true)) 562 return true; 563 564 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 565 if (!detensorableBranchOps.count(branchOp)) 566 return true; 567 568 for (auto operandIdx : detensorableBranchOps[branchOp]) 569 if (!typeConverter.isLegal( 570 branchOp->getOperand(operandIdx).getType())) 571 return false; 572 573 return true; 574 } 575 576 return false; 577 }); 578 579 patterns.insert<DetensorizeGenericOp>(typeConverter, context); 580 patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter, 581 blockArgsToDetensor); 582 // Since non-entry block arguments get detensorized, we also need to 583 // update the control flow inside the function to reflect the correct 584 // types. 
585 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 586 int operandIdx) -> bool { 587 return detensorableBranchOps.count(branchOp) && 588 detensorableBranchOps[branchOp].count(operandIdx); 589 }; 590 591 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 592 shouldConvertBranchOperand); 593 594 if (failed( 595 applyFullConversion(getOperation(), target, std::move(patterns)))) 596 signalPassFailure(); 597 598 RewritePatternSet canonPatterns(context); 599 canonPatterns.add<ExtractFromReshapeFromElements>(context); 600 if (failed(applyPatternsAndFoldGreedily(getOperation(), 601 std::move(canonPatterns)))) 602 signalPassFailure(); 603 } 604 605 Option<bool> aggressiveMode{ 606 *this, "aggressive-mode", 607 llvm::cl::desc("Detensorize all ops that qualify for detensoring along " 608 "with branch operands and basic-block arguments.")}; 609 }; 610 } // namespace 611 612 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 613 return std::make_unique<LinalgDetensorize>(); 614 } 615