1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 /// Conversion patterns. 21 namespace { 22 class AnyOpConversion : public OpConversionPattern<AnyOp> { 23 public: 24 using OpConversionPattern<AnyOp>::OpConversionPattern; 25 26 LogicalResult 27 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 28 ConversionPatternRewriter &rewriter) const override; 29 }; 30 } // namespace 31 32 LogicalResult 33 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 34 ConversionPatternRewriter &rewriter) const { 35 AnyOp::Adaptor transformed(operands); 36 37 // Replace `any` with its first operand. 38 // Any operand would be a valid substitution. 39 rewriter.replaceOp(op, {transformed.inputs().front()}); 40 return success(); 41 } 42 43 namespace { 44 template <typename SrcOpTy, typename DstOpTy> 45 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 46 public: 47 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 48 49 LogicalResult 50 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 51 ConversionPatternRewriter &rewriter) const override { 52 typename SrcOpTy::Adaptor transformed(operands); 53 54 // For now, only error-free types are supported by this lowering. 55 if (op.getType().template isa<SizeType>()) 56 return failure(); 57 58 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 59 transformed.rhs()); 60 return success(); 61 } 62 }; 63 } // namespace 64 65 namespace { 66 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 67 public: 68 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 69 70 LogicalResult 71 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 72 ConversionPatternRewriter &rewriter) const override { 73 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 74 return success(); 75 } 76 }; 77 } // namespace 78 79 namespace { 80 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 81 public: 82 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 83 84 LogicalResult 85 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 86 ConversionPatternRewriter &rewriter) const override; 87 }; 88 } // namespace 89 90 LogicalResult ShapeOfOpConversion::matchAndRewrite( 91 ShapeOfOp op, ArrayRef<Value> operands, 92 ConversionPatternRewriter &rewriter) const { 93 94 // For now, only error-free types are supported by this lowering. 95 if (op.getType().isa<ShapeType>()) 96 return failure(); 97 98 // For unranked tensors `shape_of` lowers to `scf` and the pattern can be 99 // found in the corresponding pass. 100 ShapeOfOp::Adaptor transformed(operands); 101 Value tensorVal = transformed.arg(); 102 Type tensorTy = tensorVal.getType(); 103 if (tensorTy.isa<UnrankedTensorType>()) 104 return failure(); 105 106 // Build values for individual dimensions. 107 SmallVector<Value, 8> dimValues; 108 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 109 int64_t rank = rankedTensorTy.getRank(); 110 auto loc = op.getLoc(); 111 for (int64_t i = 0; i < rank; i++) { 112 if (rankedTensorTy.isDynamicDim(i)) { 113 Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i); 114 dimValues.push_back(dimVal); 115 } else { 116 int64_t dim = rankedTensorTy.getDimSize(i); 117 Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim); 118 dimValues.push_back(dimVal); 119 } 120 } 121 122 // Materialize extent tensor. 123 Value staticExtentTensor = 124 rewriter.create<TensorFromElementsOp>(loc, dimValues); 125 rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor, 126 op.getType()); 127 return success(); 128 } 129 130 namespace { 131 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 132 public: 133 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 134 135 LogicalResult 136 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 137 ConversionPatternRewriter &rewriter) const override; 138 }; 139 } // namespace 140 141 LogicalResult ConstShapeOpConverter::matchAndRewrite( 142 ConstShapeOp op, ArrayRef<Value> operands, 143 ConversionPatternRewriter &rewriter) const { 144 145 // For now, this lowering supports only extent tensors, not `shape.shape` 146 // types. 147 if (op.getType().isa<ShapeType>()) 148 return failure(); 149 150 auto loc = op.getLoc(); 151 SmallVector<Value, 4> extentOperands; 152 for (auto extent : op.shape()) { 153 extentOperands.push_back( 154 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 155 } 156 Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands); 157 Type indexTy = rewriter.getIndexType(); 158 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 159 rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy); 160 return success(); 161 } 162 163 namespace { 164 class ToExtentTensorOpConversion 165 : public OpConversionPattern<ToExtentTensorOp> { 166 public: 167 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 168 169 LogicalResult 170 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 171 ConversionPatternRewriter &rewriter) const override { 172 ToExtentTensorOpAdaptor adaptor(operands); 173 174 if (!adaptor.input().getType().isa<RankedTensorType>()) 175 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 176 177 rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(), 178 op.getType()); 179 return success(); 180 } 181 }; 182 } // namespace 183 184 namespace { 185 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 186 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 187 188 LogicalResult 189 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 190 ConversionPatternRewriter &rewriter) const override; 191 }; 192 } // namespace 193 194 LogicalResult GetExtentOpConverter::matchAndRewrite( 195 GetExtentOp op, ArrayRef<Value> operands, 196 ConversionPatternRewriter &rewriter) const { 197 GetExtentOp::Adaptor transformed(operands); 198 199 // For now, only error-free types are supported by this lowering. 200 if (op.getType().isa<SizeType>()) 201 return failure(); 202 203 // Derive shape extent directly from shape origin if possible. This 204 // circumvents the necessity to materialize the shape in memory. 205 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 206 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 207 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), 208 transformed.dim()); 209 return success(); 210 } 211 } 212 213 rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(), 214 transformed.shape(), 215 ValueRange{transformed.dim()}); 216 return success(); 217 } 218 219 namespace { 220 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 221 public: 222 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 223 224 LogicalResult 225 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 226 ConversionPatternRewriter &rewriter) const override; 227 }; 228 } // namespace 229 230 LogicalResult 231 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 232 ConversionPatternRewriter &rewriter) const { 233 // For now, this lowering supports only error-free types. 234 if (op.getType().isa<SizeType>()) 235 return failure(); 236 237 shape::RankOp::Adaptor transformed(operands); 238 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 239 return success(); 240 } 241 242 namespace { 243 /// Conversion pass. 244 class ConvertShapeToStandardPass 245 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 246 247 void runOnOperation() override; 248 }; 249 } // namespace 250 251 void ConvertShapeToStandardPass::runOnOperation() { 252 // Setup target legality. 253 MLIRContext &ctx = getContext(); 254 ConversionTarget target(ctx); 255 target.addLegalDialect<StandardOpsDialect>(); 256 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>(); 257 258 // Setup conversion patterns. 259 OwningRewritePatternList patterns; 260 populateShapeToStandardConversionPatterns(patterns, &ctx); 261 262 // Apply conversion. 263 auto module = getOperation(); 264 if (failed(applyPartialConversion(module, target, patterns))) 265 signalPassFailure(); 266 } 267 268 void mlir::populateShapeToStandardConversionPatterns( 269 OwningRewritePatternList &patterns, MLIRContext *ctx) { 270 // clang-format off 271 patterns.insert< 272 AnyOpConversion, 273 BinaryOpConversion<AddOp, AddIOp>, 274 ConstShapeOpConverter, 275 BinaryOpConversion<MulOp, MulIOp>, 276 ConstSizeOpConversion, 277 GetExtentOpConverter, 278 RankOpConverter, 279 ShapeOfOpConversion, 280 ToExtentTensorOpConversion>(ctx); 281 // clang-format on 282 } 283 284 std::unique_ptr<OperationPass<ModuleOp>> 285 mlir::createConvertShapeToStandardPass() { 286 return std::make_unique<ConvertShapeToStandardPass>(); 287 } 288