1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassRegistry.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 class ConvertCstrBroadcastableOp
23     : public OpRewritePattern<shape::CstrBroadcastableOp> {
24 public:
25   using OpRewritePattern::OpRewritePattern;
26   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
27                                 PatternRewriter &rewriter) const override {
28     if (op.getType().isa<shape::ShapeType>() ||
29         op.lhs().getType().isa<shape::ShapeType>() ||
30         op.rhs().getType().isa<shape::ShapeType>()) {
31       return rewriter.notifyMatchFailure(
32           op, "cannot convert error-propagating shapes");
33     }
34 
35     auto loc = op.getLoc();
36     Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
37     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
38 
39     // Find smaller and greater rank and extent tensor.
40     Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
41     Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
42     Value lhsRankULE =
43         rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
44     Type indexTy = rewriter.getIndexType();
45     Value lesserRank =
46         rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
47     Value greaterRank =
48         rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
49     Value lesserRankOperand =
50         rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
51     Value greaterRankOperand =
52         rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
53 
54     Value rankDiff =
55         rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
56 
57     // Generate code to compare the shapes extent by extent, and emit errors for
58     // non-broadcast-compatible shapes.
59     // Two extents are broadcast-compatible if
60     // 1. they are both equal, or
61     // 2. at least one of them is 1.
62 
63     rewriter.create<scf::ForOp>(
64         loc, rankDiff, greaterRank, one, llvm::None,
65         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
66           Value greaterRankOperandExtent = b.create<ExtractElementOp>(
67               loc, greaterRankOperand, ValueRange{iv});
68           Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
69           Value lesserRankOperandExtent = b.create<ExtractElementOp>(
70               loc, lesserRankOperand, ValueRange{ivShifted});
71 
72           Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
73               loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
74           Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
75               loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
76           Value extentsAgree =
77               b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
78                                lesserRankOperandExtent);
79           auto broadcastIsValid =
80               b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
81                              b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
82                                             lesserRankOperandExtentIsOne));
83           b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
84           b.create<scf::YieldOp>(loc);
85         });
86 
87     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
88     return success();
89   }
90 };
91 } // namespace
92 
93 namespace {
94 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
95 public:
96   using OpRewritePattern::OpRewritePattern;
97   LogicalResult matchAndRewrite(shape::CstrRequireOp op,
98                                 PatternRewriter &rewriter) const override {
99     rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
100     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
101     return success();
102   }
103 };
104 } // namespace
105 
106 void mlir::populateConvertShapeConstraintsConversionPatterns(
107     OwningRewritePatternList &patterns, MLIRContext *ctx) {
108   patterns.insert<ConvertCstrBroadcastableOp>(ctx);
109   patterns.insert<ConvertCstrRequireOp>(ctx);
110 }
111 
112 namespace {
113 // This pass eliminates shape constraints from the program, converting them to
114 // eager (side-effecting) error handling code. After eager error handling code
115 // is emitted, witnesses are satisfied, so they are replace with
116 // `shape.const_witness true`.
117 class ConvertShapeConstraints
118     : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
119   void runOnOperation() override {
120     auto func = getOperation();
121     auto *context = &getContext();
122 
123     OwningRewritePatternList patterns;
124     populateConvertShapeConstraintsConversionPatterns(patterns, context);
125 
126     if (failed(applyPatternsAndFoldGreedily(func, patterns)))
127       return signalPassFailure();
128   }
129 };
130 } // namespace
131 
132 std::unique_ptr<OperationPass<FuncOp>>
133 mlir::createConvertShapeConstraintsPass() {
134   return std::make_unique<ConvertShapeConstraints>();
135 }
136