1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include <iterator> 19 #include <memory> 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 namespace { 25 /// Defines the criteria a TensorType must follow in order to be considered 26 /// "detensorable". 27 /// 28 /// NOTE: For now, only 0-D are supported. 29 /// 30 /// Returns true if tensorType can be detensored. 31 bool canBeDetensored(TensorType tensorType) { 32 return tensorType.hasRank() && tensorType.getRank() == 0; 33 } 34 35 /// A conversion patttern for detensoring `linalg.generic` ops. 36 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 37 public: 38 using OpConversionPattern::OpConversionPattern; 39 LogicalResult 40 matchAndRewrite(GenericOp op, ArrayRef<Value> operands, 41 ConversionPatternRewriter &rewriter) const override { 42 Block *originalBlock = op->getBlock(); 43 44 // Gather some information about the op before inling its region. 45 Block *opEntryBlock = &*op.region().begin(); 46 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); 47 48 // Split the op's region before the op. This way, we have a clear insertion 49 // point in which the op can be inlined. 50 Block *newBlock = originalBlock->splitBlock(op); 51 rewriter.inlineRegionBefore(op.region(), newBlock); 52 // Now that op's region is inlined, the operands of its YieldOp are mapped 53 // to the materialized target values. Therefore, we can replace the op's 54 // uses with those of its YielOp's operands. 55 rewriter.replaceOp(op, yieldOp->getOperands()); 56 57 // No need for these intermediate blocks, merge them into 1. 58 rewriter.mergeBlocks(opEntryBlock, originalBlock, operands); 59 rewriter.mergeBlocks(newBlock, originalBlock, {}); 60 61 rewriter.eraseOp(&*Block::iterator(yieldOp)); 62 63 return success(); 64 } 65 }; 66 67 class DetensorizeTypeConverter : public TypeConverter { 68 public: 69 DetensorizeTypeConverter() { 70 addConversion([](Type type) { return type; }); 71 72 // A TensorType that can be detensored, is converted to the underlying 73 // element type. 74 addConversion([](TensorType tensorType) -> Type { 75 if (canBeDetensored(tensorType)) 76 return tensorType.getElementType(); 77 78 return tensorType; 79 }); 80 81 // A tensor value is detensoried by extracting its element(s). 82 addTargetMaterialization([](OpBuilder &builder, Type type, 83 ValueRange inputs, Location loc) -> Value { 84 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 85 }); 86 87 // A detensored value is converted back by creating a new tensor from its 88 // element(s). 89 addSourceMaterialization([](OpBuilder &builder, Type type, 90 ValueRange inputs, Location loc) -> Value { 91 auto createNewTensorOp = builder.create<tensor::FromElementsOp>( 92 loc, inputs[0].getType(), inputs[0]); 93 94 // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to 95 // a tensor<dtype> instead. 96 return builder.create<linalg::TensorReshapeOp>( 97 loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{}); 98 }); 99 } 100 }; 101 102 /// Canonicalizes the pattern of the form 103 /// 104 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 105 /// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into 106 /// tensor<i32> 107 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32> 108 /// 109 /// to just %element. 110 struct ExtractFromReshapeFromElements 111 : public OpRewritePattern<tensor::ExtractOp> { 112 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 113 114 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 115 PatternRewriter &rewriter) const final { 116 if (extract.indices().size() != 0) 117 return failure(); 118 119 auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>(); 120 if (tensorReshape == nullptr) 121 return failure(); 122 123 auto tensorFromElements = 124 tensorReshape.getOperand() 125 .getDefiningOp<mlir::tensor::FromElementsOp>(); 126 if (tensorFromElements == nullptr) 127 return failure(); 128 129 rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); 130 return success(); 131 } 132 }; 133 134 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 135 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 136 void runOnFunction() override { 137 auto *context = &getContext(); 138 DetensorizeTypeConverter typeConverter; 139 OwningRewritePatternList patterns; 140 ConversionTarget target(*context); 141 142 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; }); 143 target.addLegalDialect<linalg::LinalgDialect>(); 144 target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) { 145 // If any of the operands or results cannot be detensored, the op is 146 // considered legal and won't be detensored. 147 return llvm::any_of( 148 op.getShapedOperandTypes(), [](ShapedType shapedType) { 149 assert(shapedType.isa<TensorType>()); 150 return !canBeDetensored(shapedType.cast<TensorType>()); 151 }); 152 }); 153 154 patterns.insert<DetensorizeGenericOp>(typeConverter, context); 155 156 if (failed( 157 applyPartialConversion(getFunction(), target, std::move(patterns)))) 158 signalPassFailure(); 159 160 OwningRewritePatternList canonPatterns; 161 canonPatterns.insert<ExtractFromReshapeFromElements>(context); 162 if (failed(applyPatternsAndFoldGreedily(getFunction(), 163 std::move(canonPatterns)))) 164 signalPassFailure(); 165 166 // TODO Properly handle control flow within function boundaries. 167 } 168 }; 169 } // namespace 170 171 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 172 return std::make_unique<LinalgDetensorize>(); 173 } 174