1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include <iterator> 19 #include <memory> 20 #include <utility> 21 22 using namespace mlir; 23 using namespace mlir::linalg; 24 25 static Value sourceMaterializationCallback(OpBuilder &builder, Type type, 26 ValueRange inputs, Location loc) { 27 assert(inputs.size() == 1); 28 auto inputType = inputs[0].getType(); 29 if (inputType.isa<TensorType>()) 30 return nullptr; 31 32 // A detensored value is converted back by creating a new tensor from its 33 // element(s). 34 return builder.create<tensor::FromElementsOp>( 35 loc, RankedTensorType::get({}, inputType), inputs[0]); 36 } 37 38 namespace { 39 /// Defines the criteria a TensorType must follow in order to be considered 40 /// "detensorable". 41 /// 42 /// NOTE: For now, only 0-D tensors are supported. 43 /// 44 /// Returns true if tensorType can be detensored. 45 bool canBeDetensored(TensorType tensorType) { 46 return tensorType.hasRank() && tensorType.getRank() == 0; 47 } 48 49 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { 50 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); 51 return genericOp && 52 llvm::all_of( 53 genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) { 54 return !typeConverter.isLegal(opOperand->get().getType()); 55 }); 56 } 57 58 /// A conversion patttern for detensoring `linalg.generic` ops. 59 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 60 public: 61 using OpConversionPattern::OpConversionPattern; 62 LogicalResult 63 matchAndRewrite(GenericOp op, OpAdaptor adaptor, 64 ConversionPatternRewriter &rewriter) const override { 65 Block *originalBlock = op->getBlock(); 66 67 // Gather some information about the op before inling its region. 68 Block *opEntryBlock = &*op.region().begin(); 69 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); 70 71 // Split the op's region before the op. This way, we have a clear insertion 72 // point in which the op can be inlined. 73 Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op)); 74 rewriter.inlineRegionBefore(op.region(), newBlock); 75 // Now that op's region is inlined, the operands of its YieldOp are mapped 76 // to the materialized target values. Therefore, we can replace the op's 77 // uses with those of its YielOp's operands. 78 rewriter.replaceOp(op, yieldOp->getOperands()); 79 80 // No need for these intermediate blocks, merge them into 1. 81 rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); 82 rewriter.mergeBlocks(newBlock, originalBlock, {}); 83 84 rewriter.eraseOp(&*Block::iterator(yieldOp)); 85 86 return success(); 87 } 88 }; 89 90 /// A conversion pattern for detensoring internal (non-entry) blocks within a 91 /// function. 92 struct FunctionNonEntryBlockConversion 93 : public OpInterfaceConversionPattern<FunctionOpInterface> { 94 FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter, 95 DenseSet<BlockArgument> blockArgsToDetensor) 96 : OpInterfaceConversionPattern(converter, ctx), 97 blockArgsToDetensor(std::move(blockArgsToDetensor)) {} 98 99 LogicalResult 100 matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands, 101 ConversionPatternRewriter &rewriter) const override { 102 rewriter.startRootUpdate(op); 103 Region ®ion = op.getBody(); 104 SmallVector<TypeConverter::SignatureConversion, 2> conversions; 105 106 for (Block &block : llvm::drop_begin(region, 1)) { 107 conversions.emplace_back(block.getNumArguments()); 108 TypeConverter::SignatureConversion &back = conversions.back(); 109 110 for (BlockArgument blockArgument : block.getArguments()) { 111 int idx = blockArgument.getArgNumber(); 112 113 if (blockArgsToDetensor.count(blockArgument)) 114 back.addInputs(idx, {getTypeConverter()->convertType( 115 block.getArgumentTypes()[idx])}); 116 else 117 back.addInputs(idx, {block.getArgumentTypes()[idx]}); 118 } 119 } 120 121 if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, 122 conversions))) { 123 rewriter.cancelRootUpdate(op); 124 return failure(); 125 } 126 127 rewriter.finalizeRootUpdate(op); 128 return success(); 129 } 130 131 private: 132 const DenseSet<BlockArgument> blockArgsToDetensor; 133 }; 134 135 class DetensorizeTypeConverter : public TypeConverter { 136 public: 137 DetensorizeTypeConverter() { 138 addConversion([](Type type) { return type; }); 139 140 // A TensorType that can be detensored, is converted to the underlying 141 // element type. 142 addConversion([](TensorType tensorType) -> Type { 143 if (canBeDetensored(tensorType)) 144 return tensorType.getElementType(); 145 146 return tensorType; 147 }); 148 149 // A tensor value is detensoried by extracting its element(s). 150 addTargetMaterialization([](OpBuilder &builder, Type type, 151 ValueRange inputs, Location loc) -> Value { 152 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 153 }); 154 155 addSourceMaterialization(sourceMaterializationCallback); 156 addArgumentMaterialization(sourceMaterializationCallback); 157 } 158 }; 159 160 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 161 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 162 LinalgDetensorize() = default; 163 164 class CostModel { 165 public: 166 virtual ~CostModel() = default; 167 168 /// A cost model algorithm computes the following outputs: 169 /// 170 /// - opsToDetensor: the list of linalg ops that should be 171 /// detensored. 172 /// 173 /// - blockArgsToDetensor: since the operands and results of detensored 174 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 175 /// from a BB argument and a linalg op's output can be passed to successor 176 /// BBs), we need to maintain the sub-set of arguments that should be 177 /// detensored (i.e. converted by typeConverter) for each affected BB. 178 /// 179 /// Example: 180 /// 181 /// For the following snippet: 182 /// ... 183 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 184 /// %7 = linalg.init_tensor [] : tensor<i32> 185 /// %8 = linalg.generic #attrs 186 /// ins(%6, %6 : tensor<i32>, tensor<i32>) 187 /// outs(%7 : tensor<i32>) { 188 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 189 /// %9 = arith.addi %arg0, %arg1 : i32 190 /// linalg.yield %9 : i32 191 /// } -> tensor<i32> 192 /// %10 = "some.op"(%9) 193 /// br ^bb2(%8 : tensor<i32>) 194 /// ... 195 /// 196 /// if the cost model decides that the linalg.generic op should be 197 /// detensored, then: 198 /// - opsToDetensor should be = {linalg.generic{add}}. 199 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 200 virtual void compute(FunctionOpInterface func, 201 DetensorizeTypeConverter typeConverter, 202 DenseSet<Operation *> &opsToDetensor, 203 DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 204 205 /// From the blockArgsToDetensor set computed by a CostModel 206 /// implementation, this method computes the corresponding branch op 207 /// detensoring. The result is a map from a branch op to a subset of indices 208 /// of its operands. The indices specify which of the branch op's operands 209 /// should be detensored. 210 /// 211 /// For the previous example, this method would compute: {bb2 -> {0}}. 212 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 213 const DenseSet<BlockArgument> &blockArgsToDetensor) { 214 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 215 216 for (auto blockArgumentElem : blockArgsToDetensor) { 217 Block *block = blockArgumentElem.getOwner(); 218 219 for (PredecessorIterator pred = block->pred_begin(); 220 pred != block->pred_end(); ++pred) { 221 BranchOpInterface terminator = 222 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 223 auto blockOperands = 224 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 225 226 if (!blockOperands || blockOperands->empty()) 227 continue; 228 229 detensorableBranchOps[terminator].insert( 230 blockOperands->getBeginOperandIndex() + 231 blockArgumentElem.getArgNumber()); 232 } 233 } 234 235 return detensorableBranchOps; 236 } 237 }; 238 239 /// Detensorize linalg ops involved in control-flow within a function. 240 /// 241 /// This model starts from BranchOps and CondBranchOps within a function. For 242 /// each such branch, the model then walks the use-def chain for the branch's 243 /// condition backwards in order to understand where the condition's value 244 /// comes from. If the condition value is (indirectly) computed by a linalg op 245 /// that can be detensored, the model then continues walking the use-def chain 246 /// in order to understand where the linalg op's operands come from. This 247 /// leads to discovering a "detensoring component". A detensoring component is 248 /// the set of operations + block arguments that are involved in control-flow 249 /// AND can be detensored. 250 class ControlFlowDetectionModel : public CostModel { 251 public: 252 void compute(FunctionOpInterface func, 253 DetensorizeTypeConverter typeConverter, 254 DenseSet<Operation *> &opsToDetensor, 255 DenseSet<BlockArgument> &blockArgsToDetensor) override { 256 SmallVector<Value> workList; 257 258 func->walk([&](cf::CondBranchOp condBr) { 259 for (auto operand : condBr.getOperands()) { 260 workList.push_back(operand); 261 } 262 }); 263 264 func->walk([&](cf::BranchOp br) { 265 for (auto operand : br.getOperands()) { 266 workList.push_back(operand); 267 } 268 }); 269 270 DenseSet<Value> visitedValues; 271 DenseSet<Operation *> visitedOps; 272 273 // For a (to-be-detesored) value, check if it "escapes" the block by being 274 // passed to terminator. If it does, then workList is updated with the 275 // corresponding argument to the successor block. 276 auto updateWorkListWithSuccessorArguments = 277 [&](Value value, BranchOpInterface terminator) { 278 if (!terminator) 279 return; 280 281 for (auto operandIdx : 282 llvm::seq<unsigned>(0, terminator->getOperands().size())) { 283 Value operand = terminator->getOperand(operandIdx); 284 285 if (operand == value) { 286 auto succBlockArg = 287 terminator.getSuccessorBlockArgument(operandIdx); 288 289 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) 290 workList.push_back(*succBlockArg); 291 } 292 } 293 }; 294 295 while (!workList.empty()) { 296 Value currentItem = workList.pop_back_val(); 297 298 if (!visitedValues.insert(currentItem).second) 299 continue; 300 301 // 1 - Look forward: 302 // 1.1 - If currentItem escapes to one or more successors, add 303 // the corresponding successor arguments to workList. 304 updateWorkListWithSuccessorArguments( 305 currentItem, dyn_cast<BranchOpInterface>( 306 currentItem.getParentBlock()->getTerminator())); 307 308 // 1.2 - For each user of currentItem, add the defined values to 309 // workList. This way, the user ops can be inspected later if they are 310 // detensorable and if so, their operands will be added to workList to 311 // potentially discover other parts of the detensorable component. 312 for (auto *user : currentItem.getUsers()) 313 for (Value result : user->getResults()) 314 workList.push_back(result); 315 316 // 2 - Look backward: 317 // 2.1 - The current item is defined by a block argument. If the owner 318 // block is a non-entry one, then: 319 // * Add the argument to blockArgsToDetensor. 320 // * Walk the use-def chain backwards to add each predecessor's 321 // terminator-operands corresponding to currentItem to workList. 322 if (currentItem.dyn_cast<BlockArgument>()) { 323 BlockArgument currentItemBlockArgument = 324 currentItem.cast<BlockArgument>(); 325 Block *ownerBlock = currentItemBlockArgument.getOwner(); 326 327 // Function arguments are not detensored/converted. 328 if (&*ownerBlock->getParent()->begin() == ownerBlock) 329 continue; 330 331 // This inner-block argument is involved in control-flow, it should be 332 // detensored. 333 blockArgsToDetensor.insert(currentItemBlockArgument); 334 335 for (PredecessorIterator pred = ownerBlock->pred_begin(); 336 pred != ownerBlock->pred_end(); ++pred) { 337 BranchOpInterface predTerminator = 338 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 339 340 // TODO: For now, we give up if any of the control-flow components 341 // in a function is not detensorable. Fix that. 342 if (!predTerminator) { 343 opsToDetensor.clear(); 344 blockArgsToDetensor.clear(); 345 return; 346 } 347 348 auto ownerBlockOperands = 349 predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); 350 351 if (!ownerBlockOperands || ownerBlockOperands->empty()) 352 continue; 353 354 // For each predecessor, add the value it passes to that argument to 355 // workList to find out how it's computed. 356 workList.push_back( 357 ownerBlockOperands 358 .getValue()[currentItemBlockArgument.getArgNumber()]); 359 } 360 361 continue; 362 } 363 364 Operation *currentItemDefiningOp = currentItem.getDefiningOp(); 365 366 if (!visitedOps.insert(currentItemDefiningOp).second) 367 continue; 368 369 // 2.2 - The current item is computed by a GenericOp. If the op should 370 // be detensored, then: 371 // * Add it to opsToDetensor. 372 // * Add its operands to workList to discover other parts of the 373 // potentially detensorable component. 374 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { 375 // The op was encountered already, no need to inspect it again. 376 if (opsToDetensor.count(genericOp)) 377 continue; 378 379 // The op should not be detensored, give up on it but continue with 380 // discovering the rest of the control-flow component. 381 if (!shouldBeDetensored(genericOp, typeConverter)) { 382 continue; 383 } 384 385 opsToDetensor.insert(genericOp); 386 387 for (Value genericOpOperand : genericOp.inputs()) 388 workList.push_back(genericOpOperand); 389 390 continue; 391 } 392 393 // 2.3 - The current item is the result of a FromElementsOp, it will be 394 // trivially detensored later as part of canonicalization patterns 395 // applied at the end of detensoring. 396 // 397 // Note: No need to check whether the result type of this op is 398 // detensorable since if it wasn't we wouldn't reach that point in the 399 // work list. 400 if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp)) 401 continue; 402 403 // 2.4 - The current item is the result of a scalar op, add all its 404 // operands to the work list. 405 if (llvm::all_of( 406 currentItemDefiningOp->getResultTypes(), 407 [&](Type resultType) { return resultType.isIntOrFloat(); })) 408 for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) 409 workList.push_back(scalarOpOperand); 410 } 411 412 // Since the cost model gives up on some ops (see the details of step 2.2 413 // above), block arguments that correspond to the values produced by those 414 // ops should not be detensored as well. 415 416 DenseSet<BlockArgument> blockArgsToRemove; 417 418 for (auto &blockArg : blockArgsToDetensor) { 419 Block *block = blockArg.getParentBlock(); 420 421 // For the potentially detensorable block argument, find the 422 // correpsonding operands in predecessor blocks. 423 for (PredecessorIterator pred = block->pred_begin(); 424 pred != block->pred_end(); ++pred) { 425 BranchOpInterface terminator = 426 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 427 auto blockOperands = 428 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 429 430 if (!blockOperands || blockOperands->empty()) 431 continue; 432 433 Operation *definingOp = 434 terminator 435 ->getOperand(blockOperands->getBeginOperandIndex() + 436 blockArg.getArgNumber()) 437 .getDefiningOp(); 438 439 // If the operand is defined by a GenericOp that will not be 440 // detensored, then do not detensor the corresponding block argument. 441 if (dyn_cast_or_null<GenericOp>(definingOp) && 442 opsToDetensor.count(definingOp) == 0) { 443 blockArgsToRemove.insert(blockArg); 444 break; 445 } 446 } 447 } 448 449 for (auto &blockArg : blockArgsToRemove) { 450 blockArgsToDetensor.erase(blockArg); 451 } 452 } 453 }; 454 455 /// Detensorize everything that can detensored. 456 class AggressiveDetensoringModel : public CostModel { 457 public: 458 void compute(FunctionOpInterface func, 459 DetensorizeTypeConverter typeConverter, 460 DenseSet<Operation *> &opsToDetensor, 461 DenseSet<BlockArgument> &blockArgsToDetensor) override { 462 func->walk([&](GenericOp genericOp) { 463 if (shouldBeDetensored(genericOp, typeConverter)) 464 opsToDetensor.insert(genericOp); 465 }); 466 467 for (Block &block : llvm::drop_begin(func.getBody(), 1)) 468 for (BlockArgument blockArgument : block.getArguments()) 469 blockArgsToDetensor.insert(blockArgument); 470 } 471 }; 472 473 void runOnOperation() override { 474 MLIRContext *context = &getContext(); 475 DetensorizeTypeConverter typeConverter; 476 RewritePatternSet patterns(context); 477 ConversionTarget target(*context); 478 DenseSet<Operation *> opsToDetensor; 479 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 480 DenseSet<BlockArgument> blockArgsToDetensor; 481 FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation()); 482 483 if (aggressiveMode.getValue()) { 484 AggressiveDetensoringModel costModel; 485 costModel.compute(funcOp, typeConverter, opsToDetensor, 486 blockArgsToDetensor); 487 } else { 488 ControlFlowDetectionModel costModel; 489 costModel.compute(funcOp, typeConverter, opsToDetensor, 490 blockArgsToDetensor); 491 } 492 493 detensorableBranchOps = 494 CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 495 496 target.addDynamicallyLegalOp<GenericOp>( 497 [&](GenericOp op) { return !opsToDetensor.count(op); }); 498 499 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 500 // A function is legal if all of its non-entry blocks are legal. We 501 // don't legalize the entry block (i.e. the function's signature) 502 // since detensoring can't happen along external calling convention 503 // boundaries, which we conservatively approximate as all function 504 // signatures. 505 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { 506 Region &body = funcOp.getBody(); 507 return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) { 508 return !llvm::any_of( 509 blockArgsToDetensor, [&](BlockArgument blockArgument) { 510 return blockArgument.getOwner() == &block && 511 !typeConverter.isLegal(blockArgument.getType()); 512 }); 513 }); 514 } 515 516 if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 517 isLegalForReturnOpTypeConversionPattern(op, typeConverter, 518 /*returnOpAlwaysLegal*/ true)) 519 return true; 520 521 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 522 if (!detensorableBranchOps.count(branchOp)) 523 return true; 524 525 for (auto operandIdx : detensorableBranchOps[branchOp]) 526 if (!typeConverter.isLegal( 527 branchOp->getOperand(operandIdx).getType())) 528 return false; 529 530 return true; 531 } 532 533 return false; 534 }); 535 536 patterns.insert<DetensorizeGenericOp>(typeConverter, context); 537 patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter, 538 blockArgsToDetensor); 539 // Since non-entry block arguments get detensorized, we also need to 540 // update the control flow inside the function to reflect the correct 541 // types. 542 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 543 int operandIdx) -> bool { 544 return detensorableBranchOps.count(branchOp) && 545 detensorableBranchOps[branchOp].count(operandIdx); 546 }; 547 548 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 549 shouldConvertBranchOperand); 550 551 if (failed( 552 applyFullConversion(getOperation(), target, std::move(patterns)))) 553 signalPassFailure(); 554 555 RewritePatternSet canonPatterns(context); 556 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context); 557 if (failed(applyPatternsAndFoldGreedily(getOperation(), 558 std::move(canonPatterns)))) 559 signalPassFailure(); 560 } 561 }; 562 } // namespace 563 564 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 565 return std::make_unique<LinalgDetensorize>(); 566 } 567