1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 namespace { 21 22 /// Generated conversion patterns. 23 #include "ShapeToStandardPatterns.inc" 24 25 /// Conversion patterns. 26 template <typename SrcOpTy, typename DstOpTy> 27 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 28 public: 29 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 30 31 LogicalResult 32 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 33 ConversionPatternRewriter &rewriter) const override { 34 typename SrcOpTy::Adaptor adaptor(operands); 35 rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(), 36 adaptor.rhs()); 37 return success(); 38 } 39 }; 40 41 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 42 public: 43 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 44 45 LogicalResult 46 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 47 ConversionPatternRewriter &rewriter) const override { 48 ShapeOfOp::Adaptor transformed(operands); 49 auto loc = op.getLoc(); 50 auto tensorVal = transformed.arg(); 51 auto tensorTy = tensorVal.getType(); 52 53 // For unranked tensors `shape_of` lowers to `scf` and the pattern can be 54 // found in the corresponding pass. 55 if (tensorTy.isa<UnrankedTensorType>()) 56 return failure(); 57 58 // Build values for individual dimensions. 59 SmallVector<Value, 8> dimValues; 60 auto rankedTensorTy = tensorTy.cast<RankedTensorType>(); 61 int64_t rank = rankedTensorTy.getRank(); 62 for (int64_t i = 0; i < rank; i++) { 63 if (rankedTensorTy.isDynamicDim(i)) { 64 auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i); 65 dimValues.push_back(dimVal); 66 } else { 67 int64_t dim = rankedTensorTy.getDimSize(i); 68 auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim); 69 dimValues.push_back(dimVal); 70 } 71 } 72 73 // Materialize shape as ranked tensor. 74 rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(), 75 dimValues); 76 return success(); 77 } 78 }; 79 80 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> { 81 public: 82 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 83 84 LogicalResult 85 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 86 ConversionPatternRewriter &rewriter) const override { 87 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), 88 op.value().getSExtValue()); 89 return success(); 90 } 91 }; 92 93 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 94 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 95 96 LogicalResult 97 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 98 ConversionPatternRewriter &rewriter) const override { 99 GetExtentOp::Adaptor transformed(operands); 100 101 // Derive shape extent directly from shape origin if possible. 102 // This circumvents the necessity to materialize the shape in memory. 103 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 104 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), 105 transformed.dim()); 106 return success(); 107 } 108 109 rewriter.replaceOpWithNewOp<ExtractElementOp>( 110 op, rewriter.getIndexType(), transformed.shape(), 111 ValueRange{transformed.dim()}); 112 return success(); 113 } 114 }; 115 116 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 117 public: 118 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 119 120 LogicalResult 121 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 122 ConversionPatternRewriter &rewriter) const override { 123 shape::RankOp::Adaptor transformed(operands); 124 rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(), 125 0); 126 return success(); 127 } 128 }; 129 130 /// Type conversions. 131 class ShapeTypeConverter : public TypeConverter { 132 public: 133 using TypeConverter::convertType; 134 135 ShapeTypeConverter(MLIRContext *ctx) { 136 // Add default pass-through conversion. 137 addConversion([&](Type type) { return type; }); 138 139 addConversion([ctx](SizeType type) { return IndexType::get(ctx); }); 140 addConversion([ctx](ShapeType type) { 141 return RankedTensorType::get({ShapedType::kDynamicSize}, 142 IndexType::get(ctx)); 143 }); 144 } 145 }; 146 147 /// Conversion pass. 148 class ConvertShapeToStandardPass 149 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 150 151 void runOnOperation() override { 152 // Setup type conversion. 153 MLIRContext &ctx = getContext(); 154 ShapeTypeConverter typeConverter(&ctx); 155 156 // Setup target legality. 157 ConversionTarget target(ctx); 158 target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>(); 159 target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>(); 160 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 161 return typeConverter.isSignatureLegal(op.getType()) && 162 typeConverter.isLegal(&op.getBody()); 163 }); 164 165 // Setup conversion patterns. 166 OwningRewritePatternList patterns; 167 populateShapeToStandardConversionPatterns(patterns, &ctx); 168 populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); 169 170 // Apply conversion. 171 auto module = getOperation(); 172 if (failed(applyFullConversion(module, target, patterns))) 173 signalPassFailure(); 174 } 175 }; 176 177 } // namespace 178 179 void mlir::populateShapeToStandardConversionPatterns( 180 OwningRewritePatternList &patterns, MLIRContext *ctx) { 181 populateWithGenerated(ctx, &patterns); 182 // clang-format off 183 patterns.insert< 184 BinaryOpConversion<AddOp, AddIOp>, 185 BinaryOpConversion<MulOp, MulIOp>, 186 ConstSizeOpConverter, 187 GetExtentOpConverter, 188 RankOpConverter, 189 ShapeOfOpConversion>(ctx); 190 // clang-format on 191 } 192 193 std::unique_ptr<OperationPass<ModuleOp>> 194 mlir::createConvertShapeToStandardPass() { 195 return std::make_unique<ConvertShapeToStandardPass>(); 196 } 197