1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassRegistry.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 class ConvertCstrBroadcastableOp
25     : public OpRewritePattern<shape::CstrBroadcastableOp> {
26 public:
27   using OpRewritePattern::OpRewritePattern;
28   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
29                                 PatternRewriter &rewriter) const override {
30     if (op.getType().isa<shape::ShapeType>() ||
31         op.lhs().getType().isa<shape::ShapeType>() ||
32         op.rhs().getType().isa<shape::ShapeType>()) {
33       return rewriter.notifyMatchFailure(
34           op, "cannot convert error-propagating shapes");
35     }
36 
37     auto loc = op.getLoc();
38     Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
39     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
40 
41     // Find smaller and greater rank and extent tensor.
42     Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
43     Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
44     Value lhsRankULE =
45         rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
46     Type indexTy = rewriter.getIndexType();
47     Value lesserRank =
48         rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
49     Value greaterRank =
50         rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
51     Value lesserRankOperand =
52         rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
53     Value greaterRankOperand =
54         rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
55 
56     Value rankDiff =
57         rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
58 
59     // Generate code to compare the shapes extent by extent, and emit errors for
60     // non-broadcast-compatible shapes.
61     // Two extents are broadcast-compatible if
62     // 1. they are both equal, or
63     // 2. at least one of them is 1.
64 
65     rewriter.create<scf::ForOp>(
66         loc, rankDiff, greaterRank, one, llvm::None,
67         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
68           Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
69               loc, greaterRankOperand, ValueRange{iv});
70           Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
71           Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
72               loc, lesserRankOperand, ValueRange{ivShifted});
73 
74           Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
75               loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
76           Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
77               loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
78           Value extentsAgree =
79               b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
80                                lesserRankOperandExtent);
81           auto broadcastIsValid =
82               b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
83                              b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
84                                             lesserRankOperandExtentIsOne));
85           b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
86           b.create<scf::YieldOp>(loc);
87         });
88 
89     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
90     return success();
91   }
92 };
93 } // namespace
94 
95 namespace {
96 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
97 public:
98   using OpRewritePattern::OpRewritePattern;
99   LogicalResult matchAndRewrite(shape::CstrRequireOp op,
100                                 PatternRewriter &rewriter) const override {
101     rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
102     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
103     return success();
104   }
105 };
106 } // namespace
107 
108 void mlir::populateConvertShapeConstraintsConversionPatterns(
109     OwningRewritePatternList &patterns, MLIRContext *ctx) {
110   patterns.insert<ConvertCstrBroadcastableOp>(ctx);
111   patterns.insert<ConvertCstrRequireOp>(ctx);
112 }
113 
114 namespace {
115 // This pass eliminates shape constraints from the program, converting them to
116 // eager (side-effecting) error handling code. After eager error handling code
117 // is emitted, witnesses are satisfied, so they are replace with
118 // `shape.const_witness true`.
119 class ConvertShapeConstraints
120     : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
121   void runOnOperation() override {
122     auto func = getOperation();
123     auto *context = &getContext();
124 
125     OwningRewritePatternList patterns;
126     populateConvertShapeConstraintsConversionPatterns(patterns, context);
127 
128     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
129       return signalPassFailure();
130   }
131 };
132 } // namespace
133 
134 std::unique_ptr<OperationPass<FuncOp>>
135 mlir::createConvertShapeConstraintsPass() {
136   return std::make_unique<ConvertShapeConstraints>();
137 }
138