//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
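/// Converts `shape.any` by forwarding one of its operands; as noted in the
/// pattern body, any operand would be a valid substitution, so the first one
/// is used.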
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
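/// Converts a binary `shape` op to the corresponding arithmetic op on `index`
/// values, e.g. `shape.add` to `addi` and `shape.mul` to `muli` (per the
/// pattern registrations below). A minimal sketch, with illustrative (not
/// necessarily version-exact) assembly:
///
///   %sum = shape.add %a, %b : index, index -> index
///
/// becomes
///
///   %sum = addi %a, %b : index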
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
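/// Converts `shape.broadcast` on `tensor<?xindex>` operands. A sketch of the
/// emitted structure (see the pattern body below for the exact ops):
///   1. Order the operands by rank with an `scf.if`.
///   2. Allocate a stack buffer of the greater rank via `alloca`.
///   3. Copy the leading, non-broadcasted extents of the greater operand.
///   4. For the remaining dimensions, select the smaller operand's extent
///      whenever the greater operand's extent is 1.
///   5. Load the result as an extent tensor with `tensor_load`.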
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  // Determine the smaller and the greater of the two ranks, together with the
  // corresponding extent tensors.
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsSmaller =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Type extentTensorTy = op.getType();
  auto ifOp = rewriter.create<IfOp>(
      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
      lhsSmaller,
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
                                               rhsRank, transformed.rhs()});
      },
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
                                               lhsRank, transformed.lhs()});
      });
  Value smallerRank = ifOp.getResult(0);
  Value smallerOperand = ifOp.getResult(1);
  Value greaterRank = ifOp.getResult(2);
  Value greaterOperand = ifOp.getResult(3);

  // Allocate stack memory for the broadcasted extent tensor.
  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});

  // Copy the leading extents of the greater operand; they have no counterpart
  // in the smaller operand and are therefore not subject to broadcasting.
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
                           Value extent = b.create<ExtractElementOp>(
                               loc, greaterOperand, ValueRange{iv});
                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
                           b.create<scf::YieldOp>(loc);
                         });

  // Determine remaining broadcasted extents.
  rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, llvm::None,
      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
        Value greaterOperandExtent =
            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
        Value greaterOperandExtentIsOne =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
            [&](OpBuilder &b, Location loc) {
              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
              Value smallerOperandExtent = b.create<ExtractElementOp>(
                  loc, smallerOperand, ValueRange{ivShifted});
              b.create<scf::YieldOp>(loc, smallerOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterOperandExtent);
            });
        Value extent = ifOp.getResult(0);
        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
        b.create<scf::YieldOp>(loc);
      });

  // Load broadcasted shape as an extent tensor.
  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
  return success();
}

namespace {
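/// Converts `shape.const_shape` to a constant extent tensor. A sketch with
/// illustrative (not necessarily version-exact) assembly:
///
///   %shape = shape.const_shape [1, 2] : tensor<?xindex>
///
/// becomes
///
///   %c1 = constant 1 : index
///   %c2 = constant 2 : index
///   %0 = tensor_from_elements %c1, %c2 : tensor<2xindex>
///   %shape = tensor_cast %0 : tensor<2xindex> to tensor<?xindex>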
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
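/// Converts `shape.const_size` to a `constant` of index type, e.g.
/// (illustrative assembly) `%s = shape.const_size 42` becomes
/// `%c42 = constant 42 : index`.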
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
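/// Converts `shape.get_extent` to `extract_element`, or directly to `dim`
/// when the shape stems from a `shape.shape_of` on a ranked operand. A sketch
/// with illustrative (not necessarily version-exact) assembly:
///
///   %extent = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
///
/// becomes
///
///   %extent = extract_element %shape[%idx] : tensor<?xindex>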
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
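/// Converts `shape.rank` on an extent tensor to the tensor's number of
/// elements, i.e. its one and only dimension (illustrative assembly):
///
///   %rank = dim %shape, %c0 : tensor<?xindex>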
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
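/// The body of `shape.reduce` is cloned into the loop with the induction
/// variable, the current extent, and the iteration arguments substituted for
/// its block arguments. A sketch of the emitted structure (illustrative, not
/// verbatim output):
///
///   %rank = dim %shape, %c0 : tensor<?xindex>
///   %result = scf.for %i = %c0 to %rank step %c1
///       iter_args(%acc = %init) -> (index) {
///     %extent = extract_element %shape[%i] : tensor<?xindex>
///     ... cloned reduction body ...
///     scf.yield %newAcc : index
///   }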
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the operands' ranks and, if they agree, checks every extent
/// pairwise for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
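/// Converts `shape.shape_of` to an extent tensor. For ranked operands, each
/// extent is materialized individually (a `constant` for static dimensions, a
/// `dim` for dynamic ones) and the results are assembled with
/// `tensor_from_elements`. For unranked operands, the lowering emits a
/// `dynamic_tensor_from_elements` of size `rank(%arg)` whose body yields
/// `dim %arg, %i` for each index.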
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
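/// Converts `shape.to_extent_tensor` to a `tensor_cast`, which suffices once
/// the operand has already been lowered to a ranked extent tensor.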
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Setup conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

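/// Creates the Shape-to-Standard conversion pass. Assuming the usual
/// Passes.td registration for this pass, it can be exercised from the command
/// line roughly as `mlir-opt --convert-shape-to-std input.mlir` (the flag
/// name is an assumption; confirm against the registration in use).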
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}