1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassRegistry.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 class ConvertCstrBroadcastableOp
23     : public OpRewritePattern<shape::CstrBroadcastableOp> {
24 public:
25   using OpRewritePattern::OpRewritePattern;
26   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
27                                 PatternRewriter &rewriter) const override {
28     if (op.getType().isa<shape::ShapeType>() ||
29         op.lhs().getType().isa<shape::ShapeType>() ||
30         op.rhs().getType().isa<shape::ShapeType>()) {
31       return rewriter.notifyMatchFailure(
32           op, "cannot convert error-propagating shapes");
33     }
34 
35     auto loc = op.getLoc();
36     Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
37     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
38 
39     // Find smaller and greater rank and extent tensor.
40     Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
41     Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
42     Value lhsSmaller =
43         rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
44     Type indexTy = rewriter.getIndexType();
45     Type extentTensorTy = op.lhs().getType();
46     auto ifOp = rewriter.create<scf::IfOp>(
47         loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
48         lhsSmaller,
49         [&](OpBuilder &b, Location loc) {
50           b.create<scf::YieldOp>(
51               loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
52         },
53         [&](OpBuilder &b, Location loc) {
54           b.create<scf::YieldOp>(
55               loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
56         });
57     Value lesserRank = ifOp.getResult(0);
58     Value lesserRankOperand = ifOp.getResult(1);
59     Value greaterRank = ifOp.getResult(2);
60     Value greaterRankOperand = ifOp.getResult(3);
61 
62     Value rankDiff =
63         rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
64 
65     // Generate code to compare the shapes extent by extent, and emit errors for
66     // non-broadcast-compatible shapes.
67     // Two extents are broadcast-compatible if
68     // 1. they are both equal, or
69     // 2. at least one of them is 1.
70 
71     rewriter.create<scf::ForOp>(
72         loc, rankDiff, greaterRank, one, llvm::None,
73         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
74           Value greaterRankOperandExtent = b.create<ExtractElementOp>(
75               loc, greaterRankOperand, ValueRange{iv});
76           Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
77           Value lesserRankOperandExtent = b.create<ExtractElementOp>(
78               loc, lesserRankOperand, ValueRange{ivShifted});
79 
80           Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
81               loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
82           Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
83               loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
84           Value extentsAgree =
85               b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
86                                lesserRankOperandExtent);
87           auto broadcastIsValid =
88               b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
89                              b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
90                                             lesserRankOperandExtentIsOne));
91           b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
92           b.create<scf::YieldOp>(loc);
93         });
94 
95     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
96     return success();
97   }
98 };
99 } // namespace
100 
101 namespace {
102 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
103 public:
104   using OpRewritePattern::OpRewritePattern;
105   LogicalResult matchAndRewrite(shape::CstrRequireOp op,
106                                 PatternRewriter &rewriter) const override {
107     rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
108     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
109     return success();
110   }
111 };
112 } // namespace
113 
114 void mlir::populateConvertShapeConstraintsConversionPatterns(
115     OwningRewritePatternList &patterns, MLIRContext *ctx) {
116   patterns.insert<ConvertCstrBroadcastableOp>(ctx);
117   patterns.insert<ConvertCstrRequireOp>(ctx);
118 }
119 
120 namespace {
// This pass eliminates shape constraints from the program, converting them to
// eager (side-effecting) error handling code. After eager error handling code
// is emitted, witnesses are satisfied, so they are replaced with
// `shape.const_witness true`.
125 class ConvertShapeConstraints
126     : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
127   void runOnOperation() override {
128     auto func = getOperation();
129     auto *context = &getContext();
130 
131     OwningRewritePatternList patterns;
132     populateConvertShapeConstraintsConversionPatterns(patterns, context);
133 
134     if (failed(applyPatternsAndFoldGreedily(func, patterns)))
135       return signalPassFailure();
136   }
137 };
138 } // namespace
139 
140 std::unique_ptr<OperationPass<FuncOp>>
141 mlir::createConvertShapeConstraintsPass() {
142   return std::make_unique<ConvertShapeConstraints>();
143 }
144