1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include <iterator> 19 #include <memory> 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 static Value sourceMaterializationCallback(OpBuilder &builder, Type type, 25 ValueRange inputs, Location loc) { 26 assert(inputs.size() == 1); 27 // A detensored value is converted back by creating a new tensor from its 28 // element(s). 29 auto createNewTensorOp = builder.create<tensor::FromElementsOp>( 30 loc, inputs[0].getType(), inputs[0]); 31 32 // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to 33 // a tensor<dtype> instead. 34 return builder.create<linalg::TensorCollapseShapeOp>( 35 loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{}); 36 } 37 38 namespace { 39 /// Defines the criteria a TensorType must follow in order to be considered 40 /// "detensorable". 41 /// 42 /// NOTE: For now, only 0-D tensors are supported. 43 /// 44 /// Returns true if tensorType can be detensored. 45 bool canBeDetensored(TensorType tensorType) { 46 return tensorType.hasRank() && tensorType.getRank() == 0; 47 } 48 49 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { 50 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); 51 return genericOp && llvm::all_of(genericOp.getShapedOperandTypes(), 52 [&](ShapedType shapedType) { 53 return !typeConverter.isLegal(shapedType); 54 }); 55 } 56 57 /// A conversion patttern for detensoring `linalg.generic` ops. 58 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 59 public: 60 using OpConversionPattern::OpConversionPattern; 61 LogicalResult 62 matchAndRewrite(GenericOp op, ArrayRef<Value> operands, 63 ConversionPatternRewriter &rewriter) const override { 64 Block *originalBlock = op->getBlock(); 65 66 // Gather some information about the op before inling its region. 67 Block *opEntryBlock = &*op.region().begin(); 68 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); 69 70 // Split the op's region before the op. This way, we have a clear insertion 71 // point in which the op can be inlined. 72 Block *newBlock = originalBlock->splitBlock(op); 73 rewriter.inlineRegionBefore(op.region(), newBlock); 74 // Now that op's region is inlined, the operands of its YieldOp are mapped 75 // to the materialized target values. Therefore, we can replace the op's 76 // uses with those of its YielOp's operands. 77 rewriter.replaceOp(op, yieldOp->getOperands()); 78 79 // No need for these intermediate blocks, merge them into 1. 80 rewriter.mergeBlocks(opEntryBlock, originalBlock, operands); 81 rewriter.mergeBlocks(newBlock, originalBlock, {}); 82 83 rewriter.eraseOp(&*Block::iterator(yieldOp)); 84 85 return success(); 86 } 87 }; 88 89 /// A conversion pattern for detensoring internal (non-entry) blocks within a 90 /// function. 91 struct FunctionNonEntryBlockConversion : public ConversionPattern { 92 FunctionNonEntryBlockConversion(StringRef functionLikeOpName, 93 MLIRContext *ctx, TypeConverter &converter, 94 DenseSet<BlockArgument> blockArgsToDetensor) 95 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx), 96 blockArgsToDetensor(blockArgsToDetensor) {} 97 98 LogicalResult 99 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 100 ConversionPatternRewriter &rewriter) const override { 101 rewriter.startRootUpdate(op); 102 Region ®ion = function_like_impl::getFunctionBody(op); 103 SmallVector<TypeConverter::SignatureConversion, 2> conversions; 104 105 for (Block &block : llvm::drop_begin(region, 1)) { 106 conversions.emplace_back(block.getNumArguments()); 107 TypeConverter::SignatureConversion &back = conversions.back(); 108 109 for (BlockArgument blockArgument : block.getArguments()) { 110 int idx = blockArgument.getArgNumber(); 111 112 if (blockArgsToDetensor.count(blockArgument)) 113 back.addInputs(idx, {getTypeConverter()->convertType( 114 block.getArgumentTypes()[idx])}); 115 else 116 back.addInputs(idx, {block.getArgumentTypes()[idx]}); 117 } 118 } 119 120 if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, 121 conversions))) { 122 rewriter.cancelRootUpdate(op); 123 return failure(); 124 } 125 126 rewriter.finalizeRootUpdate(op); 127 return success(); 128 } 129 130 private: 131 const DenseSet<BlockArgument> blockArgsToDetensor; 132 }; 133 134 class DetensorizeTypeConverter : public TypeConverter { 135 public: 136 DetensorizeTypeConverter() { 137 addConversion([](Type type) { return type; }); 138 139 // A TensorType that can be detensored, is converted to the underlying 140 // element type. 141 addConversion([](TensorType tensorType) -> Type { 142 if (canBeDetensored(tensorType)) 143 return tensorType.getElementType(); 144 145 return tensorType; 146 }); 147 148 // A tensor value is detensoried by extracting its element(s). 149 addTargetMaterialization([](OpBuilder &builder, Type type, 150 ValueRange inputs, Location loc) -> Value { 151 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 152 }); 153 154 addSourceMaterialization(sourceMaterializationCallback); 155 addArgumentMaterialization(sourceMaterializationCallback); 156 } 157 }; 158 159 /// Canonicalizes the pattern of the form 160 /// 161 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 162 /// %reshaped_tensor = linalg.tensor_collapse_shape %tensor [] 163 /// : tensor<1xi32> into tensor<i32> 164 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32> 165 /// 166 /// to just %element. 167 struct ExtractFromReshapeFromElements 168 : public OpRewritePattern<tensor::ExtractOp> { 169 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 170 171 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 172 PatternRewriter &rewriter) const final { 173 if (!extract.indices().empty()) 174 return failure(); 175 176 auto tensorReshape = 177 extract.tensor().getDefiningOp<TensorCollapseShapeOp>(); 178 if (tensorReshape == nullptr) 179 return failure(); 180 181 auto tensorFromElements = 182 tensorReshape.getOperand() 183 .getDefiningOp<mlir::tensor::FromElementsOp>(); 184 if (tensorFromElements == nullptr) 185 return failure(); 186 187 rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); 188 return success(); 189 } 190 }; 191 192 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 193 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 194 LinalgDetensorize() = default; 195 LinalgDetensorize(const LinalgDetensorize &pass) {} 196 197 class CostModel { 198 public: 199 virtual ~CostModel() = default; 200 201 /// A cost model algorithm computes the following outputs: 202 /// 203 /// - opsToDetensor: the list of linalg ops that should be 204 /// detensored. 205 /// 206 /// - blockArgsToDetensor: since the operands and results of detensored 207 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 208 /// from a BB argument and a linalg op's output can be passed to successor 209 /// BBs), we need to maintain the sub-set of arguments that should be 210 /// detensored (i.e. converted by typeConverter) for each affected BB. 211 /// 212 /// Example: 213 /// 214 /// For the following snippet: 215 /// ... 216 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 217 /// %7 = linalg.init_tensor [] : tensor<i32> 218 /// %8 = linalg.generic #attrs 219 /// ins(%6, %6 : tensor<i32>, tensor<i32>) 220 /// outs(%7 : tensor<i32>) { 221 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 222 /// %9 = addi %arg0, %arg1 : i32 223 /// linalg.yield %9 : i32 224 /// } -> tensor<i32> 225 /// %10 = "some.op"(%9) 226 /// br ^bb2(%8 : tensor<i32>) 227 /// ... 228 /// 229 /// if the cost model decides that the linalg.generic op should be 230 /// detensored, then: 231 /// - opsToDetensor should be = {linalg.generic{add}}. 232 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 233 virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter, 234 DenseSet<Operation *> &opsToDetensor, 235 DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 236 237 /// From the blockArgsToDetensor set computed by a CostModel 238 /// implementation, this method computes the corresponding branch op 239 /// detensoring. The result is a map from a branch op to a subset of indices 240 /// of its operands. The indices specify which of the branch op's operands 241 /// should be detensored. 242 /// 243 /// For the previous example, this method would compute: {bb2 -> {0}}. 244 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 245 const DenseSet<BlockArgument> &blockArgsToDetensor) { 246 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 247 248 for (auto blockArgumentElem : blockArgsToDetensor) { 249 Block *block = blockArgumentElem.getOwner(); 250 251 for (PredecessorIterator pred = block->pred_begin(); 252 pred != block->pred_end(); ++pred) { 253 BranchOpInterface terminator = 254 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 255 auto blockOperands = 256 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 257 258 if (!blockOperands || blockOperands->empty()) 259 continue; 260 261 detensorableBranchOps[terminator].insert( 262 blockOperands->getBeginOperandIndex() + 263 blockArgumentElem.getArgNumber()); 264 } 265 } 266 267 return detensorableBranchOps; 268 } 269 }; 270 271 /// Detensorize linalg ops involved in control-flow within a function. 272 /// 273 /// This model starts from CondBranchOps within a function. For each cond_br, 274 /// the model then walks the use-def chain for the branch's condition 275 /// backwards in order to understand where the condition's value comes from. 276 /// If the condition value is (indirectly) computed by a linalg op that can be 277 /// detensored, the model then continues walking the use-def chain in order to 278 /// understand where the linalg op's operands come from. This leads to 279 /// discovering a "detensoring component". A detensoring component is the set 280 /// of operations + block arguments that are involved in control-flow AND can 281 /// be detensored. 282 /// 283 /// For examples where this model succeeds to discover a detensoring 284 /// component, see: 285 /// - test/Dialect/Linalg/detensorize_while.mlir 286 /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir. 287 /// 288 /// For an example where this model marks control-flow as "non-detensorable", 289 /// see: 290 /// - test/Dialect/Linalg/detensorize_while_failure.mlir 291 class PureControlFlowDetectionModel : public CostModel { 292 public: 293 void compute(FuncOp func, DetensorizeTypeConverter typeConverter, 294 DenseSet<Operation *> &opsToDetensor, 295 DenseSet<BlockArgument> &blockArgsToDetensor) override { 296 SmallVector<Value> workList; 297 298 func.walk( 299 [&](CondBranchOp condBr) { workList.push_back(condBr.condition()); }); 300 301 DenseSet<Value> visitedValues; 302 DenseSet<Operation *> visitedOps; 303 304 // For a (to-be-detesored) value, check if it "escapes" the block by being 305 // passed to terminator. If it does, then workList is updated with the 306 // corresponding argument to the successor block. 307 auto updateWorkListWithSuccessorArguments = 308 [&](Value value, BranchOpInterface terminator) { 309 if (!terminator) 310 return; 311 312 for (auto operandIdx : 313 llvm::seq<unsigned>(0, terminator->getOperands().size())) { 314 Value operand = terminator->getOperand(operandIdx); 315 316 if (operand == value) { 317 auto succBlockArg = 318 terminator.getSuccessorBlockArgument(operandIdx); 319 320 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) 321 workList.push_back(*succBlockArg); 322 } 323 } 324 }; 325 326 while (!workList.empty()) { 327 Value currentItem = workList.pop_back_val(); 328 329 if (!visitedValues.insert(currentItem).second) 330 continue; 331 332 // 1 - Look forward: 333 // 1.1 - If currentItem escapes to one or more successors, add 334 // the corresponding successor arguments to workList. 335 updateWorkListWithSuccessorArguments( 336 currentItem, dyn_cast<BranchOpInterface>( 337 currentItem.getParentBlock()->getTerminator())); 338 339 // 1.2 - For each user of currentItem, add the defined values to 340 // workList. This way, the user ops can be inspected later if they are 341 // detensorable and if so, their operands will be added to workList to 342 // potentially discover other parts of the detensorable component. 343 for (auto *user : currentItem.getUsers()) 344 for (Value result : user->getResults()) 345 workList.push_back(result); 346 347 // 2 - Look backward: 348 // 2.1 - The current item is defined by a block argument. If the owner 349 // block is a non-entry one, then: 350 // * Add the argument to blockArgsToDetensor. 351 // * Walk the use-def chain backwards to add each predecessor's 352 // terminator-operands corresponding to currentItem to workList. 353 if (currentItem.dyn_cast<BlockArgument>()) { 354 BlockArgument currentItemBlockArgument = 355 currentItem.cast<BlockArgument>(); 356 Block *ownerBlock = currentItemBlockArgument.getOwner(); 357 358 // Function arguments are not detensored/converted. 359 if (&*ownerBlock->getParent()->begin() == ownerBlock) 360 continue; 361 362 // This inner-block argument is involved in control-flow, it should be 363 // detensored. 364 blockArgsToDetensor.insert(currentItemBlockArgument); 365 366 for (PredecessorIterator pred = ownerBlock->pred_begin(); 367 pred != ownerBlock->pred_end(); ++pred) { 368 BranchOpInterface terminator = 369 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 370 371 // TODO: For now, we give up if any of the control-flow components 372 // in a function is not detensorable. Fix that. 373 if (!terminator) { 374 opsToDetensor.clear(); 375 blockArgsToDetensor.clear(); 376 return; 377 } 378 379 auto ownerBlockOperands = 380 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 381 382 if (!ownerBlockOperands || ownerBlockOperands->empty()) 383 continue; 384 385 // For each predecessor, add the value it passes to that argument to 386 // workList to find out how it's computed. 387 workList.push_back( 388 ownerBlockOperands 389 .getValue()[currentItemBlockArgument.getArgNumber()]); 390 } 391 392 continue; 393 } 394 395 Operation *currentItemDefiningOp = currentItem.getDefiningOp(); 396 397 if (!visitedOps.insert(currentItemDefiningOp).second) 398 continue; 399 400 // 2.2 - The current item is computed by a GenericOp. If the op should 401 // be detensored, then: 402 // * Add it to opsToDetensor. 403 // * Add its operands to workList to discover other parts of the 404 // potentially detensorable component. 405 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { 406 // The op was encountered already, no need to inspect it again. 407 if (opsToDetensor.count(genericOp)) 408 continue; 409 410 // TODO: For now, we give up if any of the control-flow components 411 // in a function is not detensorable. Fix that. 412 if (!shouldBeDetensored(genericOp, typeConverter)) { 413 opsToDetensor.clear(); 414 blockArgsToDetensor.clear(); 415 return; 416 } 417 418 opsToDetensor.insert(genericOp); 419 420 for (Value genericOpOperand : genericOp.inputs()) 421 workList.push_back(genericOpOperand); 422 423 continue; 424 } 425 426 // 2.3 - The current item is the result of a FromElementsOp, it will be 427 // trivially detensored later as part of canonicalization patterns 428 // applied at the end of detensoring. 429 // 430 // Note: No need to check whether the result type of this op is 431 // detensorable since if it wasn't we wouldn't reach that point in the 432 // work list. 433 if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp)) 434 continue; 435 436 // 2.4 - The current item is the result of a scalar op, add all its 437 // operands to the work list. 438 if (llvm::all_of( 439 currentItemDefiningOp->getResultTypes(), 440 [&](Type resultType) { return resultType.isIntOrFloat(); })) 441 for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) 442 workList.push_back(scalarOpOperand); 443 } 444 } 445 }; 446 447 /// Detensorize everything that can detensored. 448 class AggressiveDetensoringModel : public CostModel { 449 public: 450 void compute(FuncOp func, DetensorizeTypeConverter typeConverter, 451 DenseSet<Operation *> &opsToDetensor, 452 DenseSet<BlockArgument> &blockArgsToDetensor) override { 453 func.walk([&](GenericOp genericOp) { 454 if (shouldBeDetensored(genericOp, typeConverter)) 455 opsToDetensor.insert(genericOp); 456 }); 457 458 for (Block &block : llvm::drop_begin(func.getBody(), 1)) 459 for (BlockArgument blockArgument : block.getArguments()) 460 blockArgsToDetensor.insert(blockArgument); 461 } 462 }; 463 464 void runOnFunction() override { 465 MLIRContext *context = &getContext(); 466 DetensorizeTypeConverter typeConverter; 467 RewritePatternSet patterns(context); 468 ConversionTarget target(*context); 469 DenseSet<Operation *> opsToDetensor; 470 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 471 DenseSet<BlockArgument> blockArgsToDetensor; 472 473 if (aggressiveMode.getValue()) { 474 AggressiveDetensoringModel costModel; 475 costModel.compute(getFunction(), typeConverter, opsToDetensor, 476 blockArgsToDetensor); 477 478 } else { 479 PureControlFlowDetectionModel costModel; 480 costModel.compute(getFunction(), typeConverter, opsToDetensor, 481 blockArgsToDetensor); 482 } 483 484 detensorableBranchOps = 485 CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 486 487 target.addDynamicallyLegalOp<GenericOp>( 488 [&](GenericOp op) { return !opsToDetensor.count(op); }); 489 490 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 491 // A function is legal if all of its non-entry blocks are legal. We 492 // don't legalize the entry block (i.e. the function's signature) 493 // since detensoring can't happen along external calling convention 494 // boundaries, which we conservatively approximate as all function 495 // signatures. 496 return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) { 497 if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) { 498 return blockArgument.getOwner() == &block && 499 !typeConverter.isLegal(blockArgument.getType()); 500 })) { 501 return false; 502 } 503 return true; 504 }); 505 }); 506 507 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 508 if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 509 isLegalForReturnOpTypeConversionPattern(op, typeConverter, 510 /*returnOpAlwaysLegal*/ true)) 511 return true; 512 513 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 514 if (!detensorableBranchOps.count(branchOp)) 515 return true; 516 517 for (auto operandIdx : detensorableBranchOps[branchOp]) 518 if (!typeConverter.isLegal( 519 branchOp->getOperand(operandIdx).getType())) 520 return false; 521 522 return true; 523 } 524 525 return false; 526 }); 527 528 patterns.insert<DetensorizeGenericOp>(typeConverter, context); 529 patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(), 530 context, typeConverter, 531 blockArgsToDetensor); 532 // Since non-entry block arguments get detensorized, we also need to 533 // update the control flow inside the function to reflect the correct 534 // types. 535 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 536 int operandIdx) -> bool { 537 return detensorableBranchOps.count(branchOp) && 538 detensorableBranchOps[branchOp].count(operandIdx); 539 }; 540 541 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 542 shouldConvertBranchOperand); 543 544 if (failed(applyFullConversion(getFunction(), target, std::move(patterns)))) 545 signalPassFailure(); 546 547 RewritePatternSet canonPatterns(context); 548 canonPatterns.add<ExtractFromReshapeFromElements>(context); 549 if (failed(applyPatternsAndFoldGreedily(getFunction(), 550 std::move(canonPatterns)))) 551 signalPassFailure(); 552 } 553 554 Option<bool> aggressiveMode{ 555 *this, "aggressive-mode", 556 llvm::cl::desc("Detensorize all ops that qualify for detensoring along " 557 "with branch operands and basic-block arguments.")}; 558 }; 559 } // namespace 560 561 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 562 return std::make_unique<LinalgDetensorize>(); 563 } 564