//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;

namespace {

/// Generated conversion patterns.
#include "ShapeToStandardPatterns.inc"

/// Conversion patterns.
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    AnyOp::Adaptor transformed(operands);

    // Replace `any` with its first operand.
    // Any operand would be a valid substitution.
    rewriter.replaceOp(op, {transformed.inputs().front()});
    return success();
  }
};

template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor adaptor(operands);
    rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
                                         adaptor.rhs());
    return success();
  }
};

class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ShapeOfOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto tensorVal = transformed.arg();
    auto tensorTy = tensorVal.getType();

    // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
    // found in the corresponding pass.
    if (tensorTy.isa<UnrankedTensorType>())
      return failure();

    // Build values for individual dimensions.
    SmallVector<Value, 8> dimValues;
    auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
        dimValues.push_back(dimVal);
      } else {
        int64_t dim = rankedTensorTy.getDimSize(i);
        auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
        dimValues.push_back(dimVal);
      }
    }

    // Materialize shape as ranked tensor.
    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
                                                      dimValues);
    return success();
  }
};

class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
                                                 op.value().getSExtValue());
    return success();
  }
};

class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    GetExtentOp::Adaptor transformed(operands);

    // Derive shape extent directly from shape origin if possible.
    // This circumvents the necessity to materialize the shape in memory.
    if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }

    rewriter.replaceOpWithNewOp<ExtractElementOp>(
        op, rewriter.getIndexType(), transformed.shape(),
        ValueRange{transformed.dim()});
    return success();
  }
};

class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    shape::RankOp::Adaptor transformed(operands);
    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
                                       0);
    return success();
  }
};

/// Type conversions.
class ShapeTypeConverter : public TypeConverter {
public:
  using TypeConverter::convertType;

  ShapeTypeConverter(MLIRContext *ctx) {
    // Add default pass-through conversion.
    addConversion([&](Type type) { return type; });

    addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
    addConversion([ctx](ShapeType type) {
      return RankedTensorType::get({ShapedType::kDynamicSize},
                                   IndexType::get(ctx));
    });
  }
};

/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override {
    // Setup type conversion.
    MLIRContext &ctx = getContext();
    ShapeTypeConverter typeConverter(&ctx);

    // Setup target legality.
    ConversionTarget target(ctx);
    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
    target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      return typeConverter.isSignatureLegal(op.getType()) &&
             typeConverter.isLegal(&op.getBody());
    });

    // Setup conversion patterns.
    OwningRewritePatternList patterns;
    populateShapeToStandardConversionPatterns(patterns, &ctx);
    populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);

    // Apply conversion.
    auto module = getOperation();
    if (failed(applyFullConversion(module, target, patterns)))
      signalPassFailure();
  }
};

} // namespace

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  populateWithGenerated(ctx, &patterns);
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      ConstSizeOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ShapeOfOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}
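
// Usage sketch: one way to schedule the pass created above on a module with a
// standard mlir::PassManager. This is an illustrative example, not code from
// this conversion; `context` and `module` are assumed to be an MLIRContext and
// a ModuleOp already populated with shape-dialect IR.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createConvertShapeToStandardPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "shape-to-standard lowering failed\n";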