1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 /// Conversion patterns. 21 namespace { 22 class AnyOpConversion : public OpConversionPattern<AnyOp> { 23 public: 24 using OpConversionPattern<AnyOp>::OpConversionPattern; 25 26 LogicalResult 27 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 28 ConversionPatternRewriter &rewriter) const override; 29 }; 30 } // namespace 31 32 LogicalResult 33 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 34 ConversionPatternRewriter &rewriter) const { 35 AnyOp::Adaptor transformed(operands); 36 37 // Replace `any` with its first operand. 38 // Any operand would be a valid substitution. 39 rewriter.replaceOp(op, {transformed.inputs().front()}); 40 return success(); 41 } 42 43 namespace { 44 template <typename SrcOpTy, typename DstOpTy> 45 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 46 public: 47 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 48 49 LogicalResult 50 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 51 ConversionPatternRewriter &rewriter) const override { 52 typename SrcOpTy::Adaptor adaptor(operands); 53 rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs()); 54 return success(); 55 } 56 }; 57 } // namespace 58 59 namespace { 60 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 61 public: 62 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 63 64 LogicalResult 65 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 66 ConversionPatternRewriter &rewriter) const override; 67 }; 68 } // namespace 69 70 LogicalResult ShapeOfOpConversion::matchAndRewrite( 71 ShapeOfOp op, ArrayRef<Value> operands, 72 ConversionPatternRewriter &rewriter) const { 73 ShapeOfOp::Adaptor transformed(operands); 74 auto loc = op.getLoc(); 75 auto tensorVal = transformed.arg(); 76 auto tensorTy = tensorVal.getType(); 77 78 // For unranked tensors `shape_of` lowers to `scf` and the pattern can be 79 // found in the corresponding pass. 80 if (tensorTy.isa<UnrankedTensorType>()) 81 return failure(); 82 83 // Build values for individual dimensions. 84 SmallVector<Value, 8> dimValues; 85 auto rankedTensorTy = tensorTy.cast<RankedTensorType>(); 86 int64_t rank = rankedTensorTy.getRank(); 87 for (int64_t i = 0; i < rank; i++) { 88 if (rankedTensorTy.isDynamicDim(i)) { 89 auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i); 90 dimValues.push_back(dimVal); 91 } else { 92 int64_t dim = rankedTensorTy.getDimSize(i); 93 auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim); 94 dimValues.push_back(dimVal); 95 } 96 } 97 98 // Materialize extent tensor. 99 Value staticExtentTensor = 100 rewriter.create<TensorFromElementsOp>(loc, dimValues); 101 rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor, 102 op.getType()); 103 return success(); 104 } 105 106 namespace { 107 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 108 public: 109 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 110 111 LogicalResult 112 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 113 ConversionPatternRewriter &rewriter) const override; 114 }; 115 } // namespace 116 117 LogicalResult ConstShapeOpConverter::matchAndRewrite( 118 ConstShapeOp op, ArrayRef<Value> operands, 119 ConversionPatternRewriter &rewriter) const { 120 121 // For now, this lowering supports only extent tensors, not `shape.shape` 122 // types. 123 if (op.getType().isa<ShapeType>()) 124 return failure(); 125 126 auto loc = op.getLoc(); 127 SmallVector<Value, 4> extentOperands; 128 for (auto extent : op.shape()) { 129 extentOperands.push_back( 130 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 131 } 132 Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands); 133 Type indexTy = rewriter.getIndexType(); 134 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 135 rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy); 136 return success(); 137 } 138 139 namespace { 140 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 141 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 142 143 LogicalResult 144 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 145 ConversionPatternRewriter &rewriter) const override; 146 }; 147 } // namespace 148 149 LogicalResult GetExtentOpConverter::matchAndRewrite( 150 GetExtentOp op, ArrayRef<Value> operands, 151 ConversionPatternRewriter &rewriter) const { 152 GetExtentOp::Adaptor transformed(operands); 153 154 // Derive shape extent directly from shape origin if possible. 155 // This circumvents the necessity to materialize the shape in memory. 156 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 157 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim()); 158 return success(); 159 } 160 161 rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(), 162 transformed.shape(), 163 ValueRange{transformed.dim()}); 164 return success(); 165 } 166 167 namespace { 168 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 169 public: 170 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 171 172 LogicalResult 173 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 174 ConversionPatternRewriter &rewriter) const override; 175 }; 176 } // namespace 177 178 LogicalResult 179 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 180 ConversionPatternRewriter &rewriter) const { 181 shape::RankOp::Adaptor transformed(operands); 182 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 183 return success(); 184 } 185 186 namespace { 187 /// Type conversions. 188 class ShapeTypeConverter : public TypeConverter { 189 public: 190 using TypeConverter::convertType; 191 192 ShapeTypeConverter(MLIRContext *ctx) { 193 // Add default pass-through conversion. 194 addConversion([&](Type type) { return type; }); 195 196 addConversion([ctx](SizeType type) { return IndexType::get(ctx); }); 197 addConversion([ctx](ShapeType type) { 198 return RankedTensorType::get({ShapedType::kDynamicSize}, 199 IndexType::get(ctx)); 200 }); 201 } 202 }; 203 } // namespace 204 205 namespace { 206 /// Conversion pass. 207 class ConvertShapeToStandardPass 208 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 209 210 void runOnOperation() override; 211 }; 212 } // namespace 213 214 void ConvertShapeToStandardPass::runOnOperation() { 215 // Setup type conversion. 216 MLIRContext &ctx = getContext(); 217 ShapeTypeConverter typeConverter(&ctx); 218 219 // Setup target legality. 220 ConversionTarget target(ctx); 221 target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>(); 222 target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>(); 223 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 224 return typeConverter.isSignatureLegal(op.getType()) && 225 typeConverter.isLegal(&op.getBody()); 226 }); 227 228 // Setup conversion patterns. 229 OwningRewritePatternList patterns; 230 populateShapeToStandardConversionPatterns(patterns, &ctx); 231 populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); 232 233 // Apply conversion. 234 auto module = getOperation(); 235 if (failed(applyFullConversion(module, target, patterns))) 236 signalPassFailure(); 237 } 238 239 void mlir::populateShapeToStandardConversionPatterns( 240 OwningRewritePatternList &patterns, MLIRContext *ctx) { 241 // clang-format off 242 patterns.insert< 243 AnyOpConversion, 244 BinaryOpConversion<AddOp, AddIOp>, 245 ConstShapeOpConverter, 246 BinaryOpConversion<MulOp, MulIOp>, 247 GetExtentOpConverter, 248 RankOpConverter, 249 ShapeOfOpConversion>(ctx); 250 // clang-format on 251 } 252 253 std::unique_ptr<OperationPass<ModuleOp>> 254 mlir::createConvertShapeToStandardPass() { 255 return std::make_unique<ConvertShapeToStandardPass>(); 256 } 257