1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 /// Conversion patterns. 21 namespace { 22 class AnyOpConversion : public OpConversionPattern<AnyOp> { 23 public: 24 using OpConversionPattern<AnyOp>::OpConversionPattern; 25 26 LogicalResult 27 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 28 ConversionPatternRewriter &rewriter) const override; 29 }; 30 } // namespace 31 32 LogicalResult 33 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 34 ConversionPatternRewriter &rewriter) const { 35 AnyOp::Adaptor transformed(operands); 36 37 // Replace `any` with its first operand. 38 // Any operand would be a valid substitution. 39 rewriter.replaceOp(op, {transformed.inputs().front()}); 40 return success(); 41 } 42 43 namespace { 44 template <typename SrcOpTy, typename DstOpTy> 45 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 46 public: 47 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 48 49 LogicalResult 50 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 51 ConversionPatternRewriter &rewriter) const override { 52 typename SrcOpTy::Adaptor transformed(operands); 53 54 // For now, only error-free types are supported by this lowering. 55 if (op.getType().template isa<SizeType>()) 56 return failure(); 57 58 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 59 transformed.rhs()); 60 return success(); 61 } 62 }; 63 } // namespace 64 65 namespace { 66 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 67 public: 68 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 69 70 LogicalResult 71 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 72 ConversionPatternRewriter &rewriter) const override { 73 74 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 75 return success(); 76 } 77 }; 78 } // namespace 79 80 namespace { 81 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 82 public: 83 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 84 85 LogicalResult 86 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 87 ConversionPatternRewriter &rewriter) const override; 88 }; 89 } // namespace 90 91 LogicalResult ShapeOfOpConversion::matchAndRewrite( 92 ShapeOfOp op, ArrayRef<Value> operands, 93 ConversionPatternRewriter &rewriter) const { 94 95 // For now, only error-free types are supported by this lowering. 96 if (op.getType().isa<ShapeType>()) 97 return failure(); 98 99 // For unranked tensors `shape_of` lowers to `scf` and the pattern can be 100 // found in the corresponding pass. 101 ShapeOfOp::Adaptor transformed(operands); 102 Value tensorVal = transformed.arg(); 103 Type tensorTy = tensorVal.getType(); 104 if (tensorTy.isa<UnrankedTensorType>()) 105 return failure(); 106 107 // Build values for individual dimensions. 108 SmallVector<Value, 8> dimValues; 109 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 110 int64_t rank = rankedTensorTy.getRank(); 111 auto loc = op.getLoc(); 112 for (int64_t i = 0; i < rank; i++) { 113 if (rankedTensorTy.isDynamicDim(i)) { 114 Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i); 115 dimValues.push_back(dimVal); 116 } else { 117 int64_t dim = rankedTensorTy.getDimSize(i); 118 Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim); 119 dimValues.push_back(dimVal); 120 } 121 } 122 123 // Materialize extent tensor. 124 Value staticExtentTensor = 125 rewriter.create<TensorFromElementsOp>(loc, dimValues); 126 rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor, 127 op.getType()); 128 return success(); 129 } 130 131 namespace { 132 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 133 public: 134 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 135 136 LogicalResult 137 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 138 ConversionPatternRewriter &rewriter) const override; 139 }; 140 } // namespace 141 142 LogicalResult ConstShapeOpConverter::matchAndRewrite( 143 ConstShapeOp op, ArrayRef<Value> operands, 144 ConversionPatternRewriter &rewriter) const { 145 146 // For now, this lowering supports only extent tensors, not `shape.shape` 147 // types. 148 if (op.getType().isa<ShapeType>()) 149 return failure(); 150 151 auto loc = op.getLoc(); 152 SmallVector<Value, 4> extentOperands; 153 for (auto extent : op.shape()) { 154 extentOperands.push_back( 155 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 156 } 157 Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands); 158 Type indexTy = rewriter.getIndexType(); 159 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 160 rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy); 161 return success(); 162 } 163 164 namespace { 165 class ToExtentTensorOpConversion 166 : public OpConversionPattern<ToExtentTensorOp> { 167 public: 168 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 169 170 LogicalResult 171 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 172 ConversionPatternRewriter &rewriter) const override { 173 ToExtentTensorOpAdaptor adaptor(operands); 174 175 if (!adaptor.input().getType().isa<RankedTensorType>()) 176 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 177 178 rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(), 179 op.getType()); 180 return success(); 181 } 182 }; 183 } // namespace 184 185 namespace { 186 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 187 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 188 189 LogicalResult 190 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 191 ConversionPatternRewriter &rewriter) const override; 192 }; 193 } // namespace 194 195 LogicalResult GetExtentOpConverter::matchAndRewrite( 196 GetExtentOp op, ArrayRef<Value> operands, 197 ConversionPatternRewriter &rewriter) const { 198 GetExtentOp::Adaptor transformed(operands); 199 200 // For now, only error-free types are supported by this lowering. 201 if (op.getType().isa<SizeType>()) 202 return failure(); 203 204 // Derive shape extent directly from shape origin if possible. This 205 // circumvents the necessity to materialize the shape in memory. 206 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 207 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 208 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), 209 transformed.dim()); 210 return success(); 211 } 212 } 213 214 rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(), 215 transformed.shape(), 216 ValueRange{transformed.dim()}); 217 return success(); 218 } 219 220 namespace { 221 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 222 public: 223 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 224 225 LogicalResult 226 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 227 ConversionPatternRewriter &rewriter) const override; 228 }; 229 } // namespace 230 231 LogicalResult 232 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 233 ConversionPatternRewriter &rewriter) const { 234 // For now, this lowering supports only error-free types. 235 if (op.getType().isa<SizeType>()) 236 return failure(); 237 238 shape::RankOp::Adaptor transformed(operands); 239 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 240 return success(); 241 } 242 243 namespace { 244 /// Conversion pass. 245 class ConvertShapeToStandardPass 246 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 247 248 void runOnOperation() override; 249 }; 250 } // namespace 251 252 void ConvertShapeToStandardPass::runOnOperation() { 253 // Setup target legality. 254 MLIRContext &ctx = getContext(); 255 ConversionTarget target(ctx); 256 target.addLegalDialect<StandardOpsDialect>(); 257 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>(); 258 259 // Setup conversion patterns. 260 OwningRewritePatternList patterns; 261 populateShapeToStandardConversionPatterns(patterns, &ctx); 262 263 // Apply conversion. 264 auto module = getOperation(); 265 if (failed(applyPartialConversion(module, target, patterns))) 266 signalPassFailure(); 267 } 268 269 void mlir::populateShapeToStandardConversionPatterns( 270 OwningRewritePatternList &patterns, MLIRContext *ctx) { 271 // clang-format off 272 patterns.insert< 273 AnyOpConversion, 274 BinaryOpConversion<AddOp, AddIOp>, 275 ConstShapeOpConverter, 276 BinaryOpConversion<MulOp, MulIOp>, 277 ConstSizeOpConversion, 278 GetExtentOpConverter, 279 RankOpConverter, 280 ShapeOfOpConversion, 281 ToExtentTensorOpConversion>(ctx); 282 // clang-format on 283 } 284 285 std::unique_ptr<OperationPass<ModuleOp>> 286 mlir::createConvertShapeToStandardPass() { 287 return std::make_unique<ConvertShapeToStandardPass>(); 288 } 289