1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include <iterator> 19 #include <memory> 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 static Value sourceMaterializationCallback(OpBuilder &builder, Type type, 25 ValueRange inputs, Location loc) { 26 assert(inputs.size() == 1); 27 // A detensored value is converted back by creating a new tensor from its 28 // element(s). 29 auto createNewTensorOp = builder.create<tensor::FromElementsOp>( 30 loc, inputs[0].getType(), inputs[0]); 31 32 // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to 33 // a tensor<dtype> instead. 34 return builder.create<linalg::TensorReshapeOp>( 35 loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{}); 36 } 37 38 namespace { 39 /// Defines the criteria a TensorType must follow in order to be considered 40 /// "detensorable". 41 /// 42 /// NOTE: For now, only 0-D are supported. 43 /// 44 /// Returns true if tensorType can be detensored. 45 bool canBeDetensored(TensorType tensorType) { 46 return tensorType.hasRank() && tensorType.getRank() == 0; 47 } 48 49 /// A conversion patttern for detensoring `linalg.generic` ops. 50 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 51 public: 52 using OpConversionPattern::OpConversionPattern; 53 LogicalResult 54 matchAndRewrite(GenericOp op, ArrayRef<Value> operands, 55 ConversionPatternRewriter &rewriter) const override { 56 Block *originalBlock = op->getBlock(); 57 58 // Gather some information about the op before inling its region. 59 Block *opEntryBlock = &*op.region().begin(); 60 YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); 61 62 // Split the op's region before the op. This way, we have a clear insertion 63 // point in which the op can be inlined. 64 Block *newBlock = originalBlock->splitBlock(op); 65 rewriter.inlineRegionBefore(op.region(), newBlock); 66 // Now that op's region is inlined, the operands of its YieldOp are mapped 67 // to the materialized target values. Therefore, we can replace the op's 68 // uses with those of its YielOp's operands. 69 rewriter.replaceOp(op, yieldOp->getOperands()); 70 71 // No need for these intermediate blocks, merge them into 1. 72 rewriter.mergeBlocks(opEntryBlock, originalBlock, operands); 73 rewriter.mergeBlocks(newBlock, originalBlock, {}); 74 75 rewriter.eraseOp(&*Block::iterator(yieldOp)); 76 77 return success(); 78 } 79 }; 80 81 /// A conversion pattern for detensoring internal (non-entry) blocks within a 82 /// function. 83 struct FunctionNonEntryBlockConversion : public ConversionPattern { 84 FunctionNonEntryBlockConversion(StringRef functionLikeOpName, 85 MLIRContext *ctx, TypeConverter &converter) 86 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} 87 88 LogicalResult 89 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 90 ConversionPatternRewriter &rewriter) const override { 91 rewriter.startRootUpdate(op); 92 93 if (failed(rewriter.convertNonEntryRegionTypes( 94 &mlir::impl::getFunctionBody(op), *typeConverter))) { 95 rewriter.cancelRootUpdate(op); 96 return failure(); 97 } 98 99 rewriter.finalizeRootUpdate(op); 100 return success(); 101 } 102 }; 103 104 class DetensorizeTypeConverter : public TypeConverter { 105 public: 106 DetensorizeTypeConverter() { 107 addConversion([](Type type) { return type; }); 108 109 // A TensorType that can be detensored, is converted to the underlying 110 // element type. 111 addConversion([](TensorType tensorType) -> Type { 112 if (canBeDetensored(tensorType)) 113 return tensorType.getElementType(); 114 115 return tensorType; 116 }); 117 118 // A tensor value is detensoried by extracting its element(s). 119 addTargetMaterialization([](OpBuilder &builder, Type type, 120 ValueRange inputs, Location loc) -> Value { 121 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 122 }); 123 124 addSourceMaterialization(sourceMaterializationCallback); 125 addArgumentMaterialization(sourceMaterializationCallback); 126 } 127 }; 128 129 /// Canonicalizes the pattern of the form 130 /// 131 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 132 /// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into 133 /// tensor<i32> 134 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32> 135 /// 136 /// to just %element. 137 struct ExtractFromReshapeFromElements 138 : public OpRewritePattern<tensor::ExtractOp> { 139 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 140 141 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 142 PatternRewriter &rewriter) const final { 143 if (extract.indices().size() != 0) 144 return failure(); 145 146 auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>(); 147 if (tensorReshape == nullptr) 148 return failure(); 149 150 auto tensorFromElements = 151 tensorReshape.getOperand() 152 .getDefiningOp<mlir::tensor::FromElementsOp>(); 153 if (tensorFromElements == nullptr) 154 return failure(); 155 156 rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); 157 return success(); 158 } 159 }; 160 161 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 162 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { 163 void runOnFunction() override { 164 auto *context = &getContext(); 165 DetensorizeTypeConverter typeConverter; 166 RewritePatternSet patterns(context); 167 ConversionTarget target(*context); 168 169 target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) { 170 // If any of the operands or results cannot be detensored (i.e. they are 171 // all legal according the DetensorizeTypeConverter), the op is considered 172 // legal and won't be detensored. 173 return llvm::any_of(op.getShapedOperandTypes(), 174 [&](ShapedType shapedType) { 175 return typeConverter.isLegal(shapedType); 176 }); 177 }); 178 179 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 180 // A function is legal if all of its non-entry blocks are legal. We don't 181 // legalize the entry block (i.e. the function's signature) since 182 // detensoring can't happen along external calling convention boundaries, 183 // which we conservatively approximate as all function signatures. 184 return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) { 185 return typeConverter.isLegal(block.getArgumentTypes()); 186 }); 187 }); 188 189 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 190 return isNotBranchOpInterfaceOrReturnLikeOp(op) || 191 isLegalForBranchOpInterfaceTypeConversionPattern(op, 192 typeConverter) || 193 isLegalForReturnOpTypeConversionPattern( 194 op, typeConverter, /*returnOpAlwaysLegal*/ true); 195 }); 196 197 patterns.add<DetensorizeGenericOp>(typeConverter, context); 198 patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(), 199 context, typeConverter); 200 // Since non-entry block arguments get detensorized, we also need to update 201 // the control flow inside the function to reflect the correct types. 202 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); 203 204 if (failed(applyFullConversion(getFunction(), target, std::move(patterns)))) 205 signalPassFailure(); 206 207 RewritePatternSet canonPatterns(context); 208 canonPatterns.add<ExtractFromReshapeFromElements>(context); 209 if (failed(applyPatternsAndFoldGreedily(getFunction(), 210 std::move(canonPatterns)))) 211 signalPassFailure(); 212 } 213 }; 214 } // namespace 215 216 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { 217 return std::make_unique<LinalgDetensorize>(); 218 } 219