1 //====----- Bufferize.cpp - Bufferization of shape ops ---------*- C++-*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 #include "mlir/Dialect/Shape/Transforms/Passes.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::shape; 20 21 namespace { 22 // Propagate tensor to memref conversions through shape.assuming ops. 23 class TypeConversionAssumingOpConverter 24 : public BufferizeOpConversionPattern<shape::AssumingOp> { 25 public: 26 using BufferizeOpConversionPattern< 27 shape::AssumingOp>::BufferizeOpConversionPattern; 28 29 LogicalResult 30 matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands, 31 ConversionPatternRewriter &rewriter) const final { 32 SmallVector<Type, 2> newResultTypes; 33 newResultTypes.reserve(assumingOp.getNumResults()); 34 for (auto result : assumingOp.getResults()) { 35 auto originalType = result.getType(); 36 Type convertedType = converter.convertType(originalType); 37 newResultTypes.push_back(convertedType); 38 } 39 40 auto newAssumingOp = rewriter.create<shape::AssumingOp>( 41 assumingOp.getLoc(), newResultTypes, assumingOp.witness()); 42 43 rewriter.replaceOp(assumingOp, newAssumingOp.getResults()); 44 rewriter.inlineRegionBefore(assumingOp.doRegion(), newAssumingOp.doRegion(), 45 newAssumingOp.doRegion().end()); 46 47 return success(); 48 } 49 }; 50 51 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> { 52 void runOnFunction() override { 53 MLIRContext &ctx = getContext(); 54 55 OwningRewritePatternList patterns; 56 BufferizeTypeConverter converter; 57 populateShapeTypeConversionPatterns(&ctx, converter, patterns); 58 59 ConversionTarget target(getContext()); 60 auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); }; 61 62 target.addDynamicallyLegalOp<AssumingOp>([&](shape::AssumingOp op) { 63 return std::all_of(op.result_type_begin(), op.result_type_end(), 64 isMemRefType); 65 }); 66 67 if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) 68 signalPassFailure(); 69 } 70 }; 71 72 } // namespace 73 74 /// Populates `patterns` with the conversion patterns of tensor->memref. 75 // 76 // TODO: Change this to work generally with any type conversions. 77 void mlir::populateShapeTypeConversionPatterns( 78 MLIRContext *context, BufferizeTypeConverter &converter, 79 OwningRewritePatternList &patterns) { 80 patterns.insert<TypeConversionAssumingOpConverter>(context, converter); 81 } 82 83 //===----------------------------------------------------------------------===// 84 // ShapeBufferizePass construction 85 //===----------------------------------------------------------------------===// 86 87 std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() { 88 return std::make_unique<ShapeBufferizePass>(); 89 } 90