1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 /// Conversion patterns. 21 namespace { 22 class AnyOpConversion : public OpConversionPattern<AnyOp> { 23 public: 24 using OpConversionPattern<AnyOp>::OpConversionPattern; 25 26 LogicalResult 27 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 28 ConversionPatternRewriter &rewriter) const override; 29 }; 30 } // namespace 31 32 LogicalResult 33 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 34 ConversionPatternRewriter &rewriter) const { 35 AnyOp::Adaptor transformed(operands); 36 37 // Replace `any` with its first operand. 38 // Any operand would be a valid substitution. 39 rewriter.replaceOp(op, {transformed.inputs().front()}); 40 return success(); 41 } 42 43 namespace { 44 template <typename SrcOpTy, typename DstOpTy> 45 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 46 public: 47 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 48 49 LogicalResult 50 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 51 ConversionPatternRewriter &rewriter) const override { 52 typename SrcOpTy::Adaptor adaptor(operands); 53 rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs()); 54 return success(); 55 } 56 }; 57 } // namespace 58 59 namespace { 60 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 61 public: 62 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 63 64 LogicalResult 65 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 66 ConversionPatternRewriter &rewriter) const override { 67 68 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 69 return success(); 70 } 71 }; 72 } // namespace 73 74 namespace { 75 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 76 public: 77 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 81 ConversionPatternRewriter &rewriter) const override; 82 }; 83 } // namespace 84 85 LogicalResult ShapeOfOpConversion::matchAndRewrite( 86 ShapeOfOp op, ArrayRef<Value> operands, 87 ConversionPatternRewriter &rewriter) const { 88 ShapeOfOp::Adaptor transformed(operands); 89 auto loc = op.getLoc(); 90 auto tensorVal = transformed.arg(); 91 auto tensorTy = tensorVal.getType(); 92 93 // For unranked tensors `shape_of` lowers to `scf` and the pattern can be 94 // found in the corresponding pass. 95 if (tensorTy.isa<UnrankedTensorType>()) 96 return failure(); 97 98 // Build values for individual dimensions. 99 SmallVector<Value, 8> dimValues; 100 auto rankedTensorTy = tensorTy.cast<RankedTensorType>(); 101 int64_t rank = rankedTensorTy.getRank(); 102 for (int64_t i = 0; i < rank; i++) { 103 if (rankedTensorTy.isDynamicDim(i)) { 104 auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i); 105 dimValues.push_back(dimVal); 106 } else { 107 int64_t dim = rankedTensorTy.getDimSize(i); 108 auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim); 109 dimValues.push_back(dimVal); 110 } 111 } 112 113 // Materialize extent tensor. 114 Value staticExtentTensor = 115 rewriter.create<TensorFromElementsOp>(loc, dimValues); 116 rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor, 117 op.getType()); 118 return success(); 119 } 120 121 namespace { 122 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 123 public: 124 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 125 126 LogicalResult 127 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 128 ConversionPatternRewriter &rewriter) const override; 129 }; 130 } // namespace 131 132 LogicalResult ConstShapeOpConverter::matchAndRewrite( 133 ConstShapeOp op, ArrayRef<Value> operands, 134 ConversionPatternRewriter &rewriter) const { 135 136 // For now, this lowering supports only extent tensors, not `shape.shape` 137 // types. 138 if (op.getType().isa<ShapeType>()) 139 return failure(); 140 141 auto loc = op.getLoc(); 142 SmallVector<Value, 4> extentOperands; 143 for (auto extent : op.shape()) { 144 extentOperands.push_back( 145 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 146 } 147 Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands); 148 Type indexTy = rewriter.getIndexType(); 149 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 150 rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy); 151 return success(); 152 } 153 154 namespace { 155 class ToExtentTensorOpConversion 156 : public OpConversionPattern<ToExtentTensorOp> { 157 public: 158 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 159 160 LogicalResult 161 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 162 ConversionPatternRewriter &rewriter) const override { 163 ToExtentTensorOpAdaptor adaptor(operands); 164 165 if (!adaptor.input().getType().isa<RankedTensorType>()) 166 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 167 168 rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(), 169 op.getType()); 170 return success(); 171 } 172 }; 173 } // namespace 174 175 namespace { 176 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 177 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 178 179 LogicalResult 180 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 181 ConversionPatternRewriter &rewriter) const override; 182 }; 183 } // namespace 184 185 LogicalResult GetExtentOpConverter::matchAndRewrite( 186 GetExtentOp op, ArrayRef<Value> operands, 187 ConversionPatternRewriter &rewriter) const { 188 GetExtentOp::Adaptor transformed(operands); 189 190 // Derive shape extent directly from shape origin if possible. 191 // This circumvents the necessity to materialize the shape in memory. 192 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 193 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim()); 194 return success(); 195 } 196 197 rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(), 198 transformed.shape(), 199 ValueRange{transformed.dim()}); 200 return success(); 201 } 202 203 namespace { 204 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 205 public: 206 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 207 208 LogicalResult 209 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 210 ConversionPatternRewriter &rewriter) const override; 211 }; 212 } // namespace 213 214 LogicalResult 215 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 216 ConversionPatternRewriter &rewriter) const { 217 shape::RankOp::Adaptor transformed(operands); 218 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 219 return success(); 220 } 221 222 namespace { 223 /// Conversion pass. 224 class ConvertShapeToStandardPass 225 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 226 227 void runOnOperation() override; 228 }; 229 } // namespace 230 231 void ConvertShapeToStandardPass::runOnOperation() { 232 // Setup target legality. 233 MLIRContext &ctx = getContext(); 234 ConversionTarget target(ctx); 235 target.addLegalDialect<StandardOpsDialect>(); 236 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>(); 237 238 // Setup conversion patterns. 239 OwningRewritePatternList patterns; 240 populateShapeToStandardConversionPatterns(patterns, &ctx); 241 242 // Apply conversion. 243 auto module = getOperation(); 244 if (failed(applyFullConversion(module, target, patterns))) 245 signalPassFailure(); 246 } 247 248 void mlir::populateShapeToStandardConversionPatterns( 249 OwningRewritePatternList &patterns, MLIRContext *ctx) { 250 // clang-format off 251 patterns.insert< 252 AnyOpConversion, 253 BinaryOpConversion<AddOp, AddIOp>, 254 ConstShapeOpConverter, 255 BinaryOpConversion<MulOp, MulIOp>, 256 ConstSizeOpConversion, 257 GetExtentOpConverter, 258 RankOpConverter, 259 ShapeOfOpConversion, 260 ToExtentTensorOpConversion>(ctx); 261 // clang-format on 262 } 263 264 std::unique_ptr<OperationPass<ModuleOp>> 265 mlir::createConvertShapeToStandardPass() { 266 return std::make_unique<ConvertShapeToStandardPass>(); 267 } 268