//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
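/// Converts `shape.any` by replacing it with one of its operands. Any operand
/// is a valid substitution; this lowering uses the first.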
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
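/// Converts binary shape ops such as `shape.add` and `shape.mul` to the
/// corresponding arithmetic op on `index` values (`addi`, `muli`).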
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
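/// Converts `shape.broadcast` on extent tensors. The lowering allocates stack
/// memory for the broadcasted result, copies over the extents of the
/// greater-ranked operand that have no counterpart in the lesser-ranked
/// operand, and resolves the remaining extents according to the broadcasting
/// rules, i.e. an extent of one is overridden by the corresponding extent of
/// the other operand.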
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  Value lesserRankOperand = rewriter.create<SelectOp>(
      loc, lhsRankULE, transformed.lhs(), transformed.rhs());
  Value greaterRankOperand = rewriter.create<SelectOp>(
      loc, lhsRankULE, transformed.rhs(), transformed.lhs());

  // Allocate stack memory for the broadcasted extent tensor.
  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});

  // Copy extents from the greater-ranked operand that have no counterpart in
  // the lesser-ranked operand, and are hence not subject to broadcasting.
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
                           Value extent = b.create<ExtractElementOp>(
                               loc, greaterRankOperand, ValueRange{iv});
                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
                           b.create<scf::YieldOp>(loc);
                         });

  // Determine remaining broadcasted extents.
  rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, llvm::None,
      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
        Value greaterOperandExtent =
            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
        Value greaterOperandExtentIsOne =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
            [&](OpBuilder &b, Location loc) {
              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
                  loc, lesserRankOperand, ValueRange{ivShifted});
              b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterOperandExtent);
            });
        Value extent = ifOp.getResult(0);
        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
        b.create<scf::YieldOp>(loc);
      });

  // Load broadcasted shape as an extent tensor.
  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
  return success();
}

namespace {
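/// Converts `shape.const_shape` to a constant extent tensor.
///
/// Example (schematic; exact op syntax and SSA names may differ):
///
/// %shape = shape.const_shape [2, 3]
///
/// becomes
///
/// %c2 = constant 2 : index
/// %c3 = constant 3 : index
/// %0 = tensor_from_elements %c2, %c3 : tensor<2xindex>
/// %shape = tensor_cast %0 : tensor<2xindex> to tensor<?xindex>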
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
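/// Converts `shape.const_size` to a `constant` op of type `index`.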
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
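/// Converts `shape.get_extent` to `extract_element`. When the shape operand
/// is produced by a `shape.shape_of` op on a shaped value, the extent is read
/// directly from that value with `dim` instead, which avoids materializing
/// the shape in memory.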
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
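/// Converts `shape.rank` on an extent tensor to the number of extents, i.e.
/// the size of the tensor's only dimension: `dim %shape, %c0`.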
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
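/// The loop iterates over the rank of the shape; in every iteration, the
/// reduction body is cloned with the induction variable, the current extent,
/// and the iteration arguments substituted for its block arguments.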
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
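/// Converts `shape.shape_of` to an extent tensor. For ranked tensor operands,
/// the extents are materialized with `dim` and `constant` ops and assembled
/// with `tensor_from_elements`; for unranked operands, the lowering emits a
/// `dynamic_tensor_from_elements` op whose body extracts one extent per
/// dimension.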
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (auto rankedTensorTy = tensorTy.dyn_cast<RankedTensorType>()) {
    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
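/// Converts `shape.to_extent_tensor` to a `tensor_cast`, assuming the input
/// already is a ranked extent tensor; other inputs are not supported yet.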
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Setup conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}