//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

/// Re-tensorizes a detensored (scalar) value: wraps `inputs[0]` in a 0-D
/// tensor via `tensor.from_elements`. Registered below as both the source and
/// the argument materialization of the detensorize type converter. Returns a
/// null Value (i.e. declines to materialize) when the input is already a
/// tensor.
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  auto inputType = inputs[0].getType();
  // Already a tensor: nothing to materialize.
  if (inputType.isa<TensorType>())
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  return builder.create<tensor::FromElementsOp>(
      loc, RankedTensorType::get({}, inputType), inputs[0]);
}

namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
44 bool canBeDetensored(TensorType tensorType) { 45 return tensorType.hasRank() && tensorType.getRank() == 0; 46 } 47 48 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { 49 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); 50 return genericOp && 51 llvm::all_of( 52 genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) { 53 return !typeConverter.isLegal(opOperand->get().getType()); 54 }); 55 } 56 57 /// A conversion patttern for detensoring `linalg.generic` ops. 58 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 59 public: 60 using OpConversionPattern::OpConversionPattern; 61 LogicalResult 62 matchAndRewrite(GenericOp op, OpAdaptor adaptor, 63 ConversionPatternRewriter &rewriter) const override { 64 Block *originalBlock = op->getBlock(); 65 66 // Gather some information about the op before inling its region. 67 Block *opEntryBlock = &*op.region().begin(); 68 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); 69 70 // Split the op's region before the op. This way, we have a clear insertion 71 // point in which the op can be inlined. 72 Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op)); 73 rewriter.inlineRegionBefore(op.region(), newBlock); 74 // Now that op's region is inlined, the operands of its YieldOp are mapped 75 // to the materialized target values. Therefore, we can replace the op's 76 // uses with those of its YielOp's operands. 77 rewriter.replaceOp(op, yieldOp->getOperands()); 78 79 // No need for these intermediate blocks, merge them into 1. 80 rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); 81 rewriter.mergeBlocks(newBlock, originalBlock, {}); 82 83 rewriter.eraseOp(&*Block::iterator(yieldOp)); 84 85 return success(); 86 } 87 }; 88 89 /// A conversion pattern for detensoring internal (non-entry) blocks within a 90 /// function. 
91 struct FunctionNonEntryBlockConversion 92 : public OpInterfaceConversionPattern<FunctionOpInterface> { 93 FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter, 94 DenseSet<BlockArgument> blockArgsToDetensor) 95 : OpInterfaceConversionPattern(converter, ctx), 96 blockArgsToDetensor(std::move(blockArgsToDetensor)) {} 97 98 LogicalResult 99 matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands, 100 ConversionPatternRewriter &rewriter) const override { 101 rewriter.startRootUpdate(op); 102 Region ®ion = op.getBody(); 103 SmallVector<TypeConverter::SignatureConversion, 2> conversions; 104 105 for (Block &block : llvm::drop_begin(region, 1)) { 106 conversions.emplace_back(block.getNumArguments()); 107 TypeConverter::SignatureConversion &back = conversions.back(); 108 109 for (BlockArgument blockArgument : block.getArguments()) { 110 int idx = blockArgument.getArgNumber(); 111 112 if (blockArgsToDetensor.count(blockArgument)) 113 back.addInputs(idx, {getTypeConverter()->convertType( 114 block.getArgumentTypes()[idx])}); 115 else 116 back.addInputs(idx, {block.getArgumentTypes()[idx]}); 117 } 118 } 119 120 if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, 121 conversions))) { 122 rewriter.cancelRootUpdate(op); 123 return failure(); 124 } 125 126 rewriter.finalizeRootUpdate(op); 127 return success(); 128 } 129 130 private: 131 const DenseSet<BlockArgument> blockArgsToDetensor; 132 }; 133 134 class DetensorizeTypeConverter : public TypeConverter { 135 public: 136 DetensorizeTypeConverter() { 137 addConversion([](Type type) { return type; }); 138 139 // A TensorType that can be detensored, is converted to the underlying 140 // element type. 141 addConversion([](TensorType tensorType) -> Type { 142 if (canBeDetensored(tensorType)) 143 return tensorType.getElementType(); 144 145 return tensorType; 146 }); 147 148 // A tensor value is detensoried by extracting its element(s). 
149 addTargetMaterialization([](OpBuilder &builder, Type type, 150 ValueRange inputs, Location loc) -> Value { 151 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 152 }); 153 154 addSourceMaterialization(sourceMaterializationCallback); 155 addArgumentMaterialization(sourceMaterializationCallback); 156 } 157 }; 158 159 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 160 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 161 LinalgDetensorize() = default; 162 163 class CostModel { 164 public: 165 virtual ~CostModel() = default; 166 167 /// A cost model algorithm computes the following outputs: 168 /// 169 /// - opsToDetensor: the list of linalg ops that should be 170 /// detensored. 171 /// 172 /// - blockArgsToDetensor: since the operands and results of detensored 173 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 174 /// from a BB argument and a linalg op's output can be passed to successor 175 /// BBs), we need to maintain the sub-set of arguments that should be 176 /// detensored (i.e. converted by typeConverter) for each affected BB. 177 /// 178 /// Example: 179 /// 180 /// For the following snippet: 181 /// ... 182 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 183 /// %7 = linalg.init_tensor [] : tensor<i32> 184 /// %8 = linalg.generic #attrs 185 /// ins(%6, %6 : tensor<i32>, tensor<i32>) 186 /// outs(%7 : tensor<i32>) { 187 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 188 /// %9 = arith.addi %arg0, %arg1 : i32 189 /// linalg.yield %9 : i32 190 /// } -> tensor<i32> 191 /// %10 = "some.op"(%9) 192 /// br ^bb2(%8 : tensor<i32>) 193 /// ... 194 /// 195 /// if the cost model decides that the linalg.generic op should be 196 /// detensored, then: 197 /// - opsToDetensor should be = {linalg.generic{add}}. 198 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 
199 virtual void compute(FunctionOpInterface func, 200 DetensorizeTypeConverter typeConverter, 201 DenseSet<Operation *> &opsToDetensor, 202 DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 203 204 /// From the blockArgsToDetensor set computed by a CostModel 205 /// implementation, this method computes the corresponding branch op 206 /// detensoring. The result is a map from a branch op to a subset of indices 207 /// of its operands. The indices specify which of the branch op's operands 208 /// should be detensored. 209 /// 210 /// For the previous example, this method would compute: {bb2 -> {0}}. 211 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 212 const DenseSet<BlockArgument> &blockArgsToDetensor) { 213 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 214 215 for (auto blockArgumentElem : blockArgsToDetensor) { 216 Block *block = blockArgumentElem.getOwner(); 217 218 for (PredecessorIterator pred = block->pred_begin(); 219 pred != block->pred_end(); ++pred) { 220 BranchOpInterface terminator = 221 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 222 auto blockOperands = 223 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 224 225 if (!blockOperands || blockOperands->empty()) 226 continue; 227 228 detensorableBranchOps[terminator].insert( 229 blockOperands->getBeginOperandIndex() + 230 blockArgumentElem.getArgNumber()); 231 } 232 } 233 234 return detensorableBranchOps; 235 } 236 }; 237 238 /// Detensorize linalg ops involved in control-flow within a function. 239 /// 240 /// This model starts from BranchOps and CondBranchOps within a function. For 241 /// each such branch, the model then walks the use-def chain for the branch's 242 /// condition backwards in order to understand where the condition's value 243 /// comes from. 
If the condition value is (indirectly) computed by a linalg op 244 /// that can be detensored, the model then continues walking the use-def chain 245 /// in order to understand where the linalg op's operands come from. This 246 /// leads to discovering a "detensoring component". A detensoring component is 247 /// the set of operations + block arguments that are involved in control-flow 248 /// AND can be detensored. 249 class ControlFlowDetectionModel : public CostModel { 250 public: 251 void compute(FunctionOpInterface func, 252 DetensorizeTypeConverter typeConverter, 253 DenseSet<Operation *> &opsToDetensor, 254 DenseSet<BlockArgument> &blockArgsToDetensor) override { 255 SmallVector<Value> workList; 256 257 func->walk([&](CondBranchOp condBr) { 258 for (auto operand : condBr.getOperands()) { 259 workList.push_back(operand); 260 } 261 }); 262 263 func->walk([&](BranchOp br) { 264 for (auto operand : br.getOperands()) { 265 workList.push_back(operand); 266 } 267 }); 268 269 DenseSet<Value> visitedValues; 270 DenseSet<Operation *> visitedOps; 271 272 // For a (to-be-detesored) value, check if it "escapes" the block by being 273 // passed to terminator. If it does, then workList is updated with the 274 // corresponding argument to the successor block. 
275 auto updateWorkListWithSuccessorArguments = 276 [&](Value value, BranchOpInterface terminator) { 277 if (!terminator) 278 return; 279 280 for (auto operandIdx : 281 llvm::seq<unsigned>(0, terminator->getOperands().size())) { 282 Value operand = terminator->getOperand(operandIdx); 283 284 if (operand == value) { 285 auto succBlockArg = 286 terminator.getSuccessorBlockArgument(operandIdx); 287 288 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) 289 workList.push_back(*succBlockArg); 290 } 291 } 292 }; 293 294 while (!workList.empty()) { 295 Value currentItem = workList.pop_back_val(); 296 297 if (!visitedValues.insert(currentItem).second) 298 continue; 299 300 // 1 - Look forward: 301 // 1.1 - If currentItem escapes to one or more successors, add 302 // the corresponding successor arguments to workList. 303 updateWorkListWithSuccessorArguments( 304 currentItem, dyn_cast<BranchOpInterface>( 305 currentItem.getParentBlock()->getTerminator())); 306 307 // 1.2 - For each user of currentItem, add the defined values to 308 // workList. This way, the user ops can be inspected later if they are 309 // detensorable and if so, their operands will be added to workList to 310 // potentially discover other parts of the detensorable component. 311 for (auto *user : currentItem.getUsers()) 312 for (Value result : user->getResults()) 313 workList.push_back(result); 314 315 // 2 - Look backward: 316 // 2.1 - The current item is defined by a block argument. If the owner 317 // block is a non-entry one, then: 318 // * Add the argument to blockArgsToDetensor. 319 // * Walk the use-def chain backwards to add each predecessor's 320 // terminator-operands corresponding to currentItem to workList. 321 if (currentItem.dyn_cast<BlockArgument>()) { 322 BlockArgument currentItemBlockArgument = 323 currentItem.cast<BlockArgument>(); 324 Block *ownerBlock = currentItemBlockArgument.getOwner(); 325 326 // Function arguments are not detensored/converted. 
327 if (&*ownerBlock->getParent()->begin() == ownerBlock) 328 continue; 329 330 // This inner-block argument is involved in control-flow, it should be 331 // detensored. 332 blockArgsToDetensor.insert(currentItemBlockArgument); 333 334 for (PredecessorIterator pred = ownerBlock->pred_begin(); 335 pred != ownerBlock->pred_end(); ++pred) { 336 BranchOpInterface predTerminator = 337 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 338 339 // TODO: For now, we give up if any of the control-flow components 340 // in a function is not detensorable. Fix that. 341 if (!predTerminator) { 342 opsToDetensor.clear(); 343 blockArgsToDetensor.clear(); 344 return; 345 } 346 347 auto ownerBlockOperands = 348 predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); 349 350 if (!ownerBlockOperands || ownerBlockOperands->empty()) 351 continue; 352 353 // For each predecessor, add the value it passes to that argument to 354 // workList to find out how it's computed. 355 workList.push_back( 356 ownerBlockOperands 357 .getValue()[currentItemBlockArgument.getArgNumber()]); 358 } 359 360 continue; 361 } 362 363 Operation *currentItemDefiningOp = currentItem.getDefiningOp(); 364 365 if (!visitedOps.insert(currentItemDefiningOp).second) 366 continue; 367 368 // 2.2 - The current item is computed by a GenericOp. If the op should 369 // be detensored, then: 370 // * Add it to opsToDetensor. 371 // * Add its operands to workList to discover other parts of the 372 // potentially detensorable component. 373 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { 374 // The op was encountered already, no need to inspect it again. 375 if (opsToDetensor.count(genericOp)) 376 continue; 377 378 // The op should not be detensored, give up on it but continue with 379 // discovering the rest of the control-flow component. 
380 if (!shouldBeDetensored(genericOp, typeConverter)) { 381 continue; 382 } 383 384 opsToDetensor.insert(genericOp); 385 386 for (Value genericOpOperand : genericOp.inputs()) 387 workList.push_back(genericOpOperand); 388 389 continue; 390 } 391 392 // 2.3 - The current item is the result of a FromElementsOp, it will be 393 // trivially detensored later as part of canonicalization patterns 394 // applied at the end of detensoring. 395 // 396 // Note: No need to check whether the result type of this op is 397 // detensorable since if it wasn't we wouldn't reach that point in the 398 // work list. 399 if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp)) 400 continue; 401 402 // 2.4 - The current item is the result of a scalar op, add all its 403 // operands to the work list. 404 if (llvm::all_of( 405 currentItemDefiningOp->getResultTypes(), 406 [&](Type resultType) { return resultType.isIntOrFloat(); })) 407 for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) 408 workList.push_back(scalarOpOperand); 409 } 410 411 // Since the cost model gives up on some ops (see the details of step 2.2 412 // above), block arguments that correspond to the values produced by those 413 // ops should not be detensored as well. 414 415 DenseSet<BlockArgument> blockArgsToRemove; 416 417 for (auto &blockArg : blockArgsToDetensor) { 418 Block *block = blockArg.getParentBlock(); 419 420 // For the potentially detensorable block argument, find the 421 // correpsonding operands in predecessor blocks. 
422 for (PredecessorIterator pred = block->pred_begin(); 423 pred != block->pred_end(); ++pred) { 424 BranchOpInterface terminator = 425 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 426 auto blockOperands = 427 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 428 429 if (!blockOperands || blockOperands->empty()) 430 continue; 431 432 Operation *definingOp = 433 terminator 434 ->getOperand(blockOperands->getBeginOperandIndex() + 435 blockArg.getArgNumber()) 436 .getDefiningOp(); 437 438 // If the operand is defined by a GenericOp that will not be 439 // detensored, then do not detensor the corresponding block argument. 440 if (dyn_cast_or_null<GenericOp>(definingOp) && 441 opsToDetensor.count(definingOp) == 0) { 442 blockArgsToRemove.insert(blockArg); 443 break; 444 } 445 } 446 } 447 448 for (auto &blockArg : blockArgsToRemove) { 449 blockArgsToDetensor.erase(blockArg); 450 } 451 } 452 }; 453 454 /// Detensorize everything that can detensored. 455 class AggressiveDetensoringModel : public CostModel { 456 public: 457 void compute(FunctionOpInterface func, 458 DetensorizeTypeConverter typeConverter, 459 DenseSet<Operation *> &opsToDetensor, 460 DenseSet<BlockArgument> &blockArgsToDetensor) override { 461 func->walk([&](GenericOp genericOp) { 462 if (shouldBeDetensored(genericOp, typeConverter)) 463 opsToDetensor.insert(genericOp); 464 }); 465 466 for (Block &block : llvm::drop_begin(func.getBody(), 1)) 467 for (BlockArgument blockArgument : block.getArguments()) 468 blockArgsToDetensor.insert(blockArgument); 469 } 470 }; 471 472 void runOnOperation() override { 473 MLIRContext *context = &getContext(); 474 DetensorizeTypeConverter typeConverter; 475 RewritePatternSet patterns(context); 476 ConversionTarget target(*context); 477 DenseSet<Operation *> opsToDetensor; 478 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 479 DenseSet<BlockArgument> blockArgsToDetensor; 480 FunctionOpInterface funcOp = 
cast<FunctionOpInterface>(getOperation()); 481 482 if (aggressiveMode.getValue()) { 483 AggressiveDetensoringModel costModel; 484 costModel.compute(funcOp, typeConverter, opsToDetensor, 485 blockArgsToDetensor); 486 } else { 487 ControlFlowDetectionModel costModel; 488 costModel.compute(funcOp, typeConverter, opsToDetensor, 489 blockArgsToDetensor); 490 } 491 492 detensorableBranchOps = 493 CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 494 495 target.addDynamicallyLegalOp<GenericOp>( 496 [&](GenericOp op) { return !opsToDetensor.count(op); }); 497 498 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 499 // A function is legal if all of its non-entry blocks are legal. We 500 // don't legalize the entry block (i.e. the function's signature) 501 // since detensoring can't happen along external calling convention 502 // boundaries, which we conservatively approximate as all function 503 // signatures. 504 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { 505 Region &body = funcOp.getBody(); 506 return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) { 507 return !llvm::any_of( 508 blockArgsToDetensor, [&](BlockArgument blockArgument) { 509 return blockArgument.getOwner() == &block && 510 !typeConverter.isLegal(blockArgument.getType()); 511 }); 512 }); 513 } 514 515 if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 516 isLegalForReturnOpTypeConversionPattern(op, typeConverter, 517 /*returnOpAlwaysLegal*/ true)) 518 return true; 519 520 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 521 if (!detensorableBranchOps.count(branchOp)) 522 return true; 523 524 for (auto operandIdx : detensorableBranchOps[branchOp]) 525 if (!typeConverter.isLegal( 526 branchOp->getOperand(operandIdx).getType())) 527 return false; 528 529 return true; 530 } 531 532 return false; 533 }); 534 535 patterns.insert<DetensorizeGenericOp>(typeConverter, context); 536 patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter, 537 
blockArgsToDetensor); 538 // Since non-entry block arguments get detensorized, we also need to 539 // update the control flow inside the function to reflect the correct 540 // types. 541 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 542 int operandIdx) -> bool { 543 return detensorableBranchOps.count(branchOp) && 544 detensorableBranchOps[branchOp].count(operandIdx); 545 }; 546 547 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 548 shouldConvertBranchOperand); 549 550 if (failed( 551 applyFullConversion(getOperation(), target, std::move(patterns)))) 552 signalPassFailure(); 553 554 RewritePatternSet canonPatterns(context); 555 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context); 556 if (failed(applyPatternsAndFoldGreedily(getOperation(), 557 std::move(canonPatterns)))) 558 signalPassFailure(); 559 } 560 }; 561 } // namespace 562 563 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 564 return std::make_unique<LinalgDetensorize>(); 565 } 566