1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 
18 using namespace mlir;
19 using namespace mlir::shape;
20 using namespace mlir::scf;
21 
22 /// Conversion patterns.
23 namespace {
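/// Converts `shape.any` by forwarding one of its operands, since any operand
/// is a valid result of the operation. The first operand is chosen here.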
24 class AnyOpConversion : public OpConversionPattern<AnyOp> {
25 public:
26   using OpConversionPattern<AnyOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
30                   ConversionPatternRewriter &rewriter) const override;
31 };
32 } // namespace
33 
34 LogicalResult
35 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
36                                  ConversionPatternRewriter &rewriter) const {
37   AnyOp::Adaptor transformed(operands);
38 
39   // Replace `any` with its first operand.
40   // Any operand would be a valid substitution.
41   rewriter.replaceOp(op, {transformed.inputs().front()});
42   return success();
43 }
44 
45 namespace {
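/// Converts the binary ops `shape.add` and `shape.mul` on `index` operands to
/// their standard dialect counterparts. For example (assembly is
/// illustrative):
///
///   %sum = shape.add %a, %b : index, index -> index
///
/// becomes
///
///   %sum = addi %a, %b : index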
46 template <typename SrcOpTy, typename DstOpTy>
47 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
48 public:
49   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
50 
51   LogicalResult
52   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
53                   ConversionPatternRewriter &rewriter) const override {
54     typename SrcOpTy::Adaptor transformed(operands);
55 
56     // For now, only error-free types are supported by this lowering.
57     if (op.getType().template isa<SizeType>())
58       return failure();
59 
60     rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
61                                          transformed.rhs());
62     return success();
63   }
64 };
65 } // namespace
66 
67 namespace {
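/// Converts `shape.broadcast` on extent tensors to a loop nest that writes
/// the broadcasted extents to stack-allocated memory and reloads them as a
/// tensor. As a rough sketch (SSA names are illustrative, not what the
/// rewriter emits verbatim):
///
///   %lhs_rank = dim %lhs, %c0 : tensor<?xindex>
///   %rhs_rank = dim %rhs, %c0 : tensor<?xindex>
///   // ... select lesser/greater rank and operand ...
///   %mem = alloca(%greater_rank) : memref<?xindex>
///   %rank_diff = subi %greater_rank, %lesser_rank : index
///   scf.for %i = %c0 to %rank_diff step %c1 {
///     // Copy the leading extents of the greater-ranked operand.
///     %e = extract_element %greater_operand[%i] : tensor<?xindex>
///     store %e, %mem[%i] : memref<?xindex>
///   }
///   scf.for %i = %rank_diff to %greater_rank step %c1 {
///     // Broadcast the remaining extents pairwise.
///     %g = extract_element %greater_operand[%i] : tensor<?xindex>
///     %is_one = cmpi "eq", %g, %c1 : index
///     %e = scf.if %is_one -> (index) {
///       // yield the lesser-ranked operand's extent
///     } else {
///       scf.yield %g : index
///     }
///     store %e, %mem[%i] : memref<?xindex>
///   }
///   %result = tensor_load %mem : memref<?xindex>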
68 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
69   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
70 
71   LogicalResult
72   matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
73                   ConversionPatternRewriter &rewriter) const override;
74 };
75 } // namespace
76 
77 LogicalResult BroadcastOpConverter::matchAndRewrite(
78     BroadcastOp op, ArrayRef<Value> operands,
79     ConversionPatternRewriter &rewriter) const {
80   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
81   // on shapes.
82   if (op.getType().isa<ShapeType>())
83     return failure();
84 
85   assert(!op.lhs().getType().isa<ShapeType>() &&
86          !op.rhs().getType().isa<ShapeType>());
87   auto loc = op.getLoc();
88   BroadcastOp::Adaptor transformed(operands);
89   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
90   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
91 
  // Determine the lesser and greater rank, together with the corresponding
  // rank-erased extent tensor operands.
93   Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
94   Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
95   Value lhsRankULE =
96       rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
97   Type indexTy = rewriter.getIndexType();
98   Value lesserRank =
99       rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
100   Value greaterRank =
101       rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
102   auto erasedRankType =
103       RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
104   Value rankErasedLhs =
105       rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
106   Value rankErasedRhs =
107       rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
108   Value lesserRankOperand =
109       rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
110   Value greaterRankOperand =
111       rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
112 
113   // Allocate stack memory for the broadcasted extent tensor.
114   Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
115   Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
116 
  // Copy the leading extents of the greater-ranked operand; they have no
  // counterpart in the lesser-ranked operand and are not subject to
  // broadcasting.
118   Value rankDiff =
119       rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
120   rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
121                          [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
122                            Value extent = b.create<ExtractElementOp>(
123                                loc, greaterRankOperand, ValueRange{iv});
124                            b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
125                            b.create<scf::YieldOp>(loc);
126                          });
127 
128   // Determine remaining broadcasted extents.
129   rewriter.create<ForOp>(
130       loc, rankDiff, greaterRank, one, llvm::None,
131       [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
132         Value greaterOperandExtent =
133             b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
134         Value greaterOperandExtentIsOne =
135             b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
136         auto ifOp = b.create<IfOp>(
137             loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
138             [&](OpBuilder &b, Location loc) {
139               Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
140               Value lesserRankOperandExtent = b.create<ExtractElementOp>(
141                   loc, lesserRankOperand, ValueRange{ivShifted});
142               b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
143             },
144             [&](OpBuilder &b, Location loc) {
145               b.create<scf::YieldOp>(loc, greaterOperandExtent);
146             });
147         Value extent = ifOp.getResult(0);
148         b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
149         b.create<scf::YieldOp>(loc);
150       });
151 
152   // Load broadcasted shape as an extent tensor.
153   rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
154   return success();
155 }
156 
157 namespace {
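/// Converts `shape.const_shape` to a tensor of constant index extents, cast
/// to the dynamically sized extent tensor type. As a rough sketch (exact
/// assembly is illustrative):
///
///   %shape = shape.const_shape [1, 2] : tensor<?xindex>
///
/// becomes
///
///   %c1 = constant 1 : index
///   %c2 = constant 2 : index
///   %0 = tensor_from_elements %c1, %c2 : tensor<2xindex>
///   %shape = tensor_cast %0 : tensor<2xindex> to tensor<?xindex>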
158 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
159 public:
160   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
161 
162   LogicalResult
163   matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
164                   ConversionPatternRewriter &rewriter) const override;
165 };
166 } // namespace
167 
168 LogicalResult ConstShapeOpConverter::matchAndRewrite(
169     ConstShapeOp op, ArrayRef<Value> operands,
170     ConversionPatternRewriter &rewriter) const {
171 
172   // For now, this lowering supports only extent tensors, not `shape.shape`
173   // types.
174   if (op.getType().isa<ShapeType>())
175     return failure();
176 
177   auto loc = op.getLoc();
178   SmallVector<Value, 4> extentOperands;
179   for (auto extent : op.shape()) {
180     extentOperands.push_back(
181         rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
182   }
183   Type indexTy = rewriter.getIndexType();
184   Value tensor =
185       rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
186   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
187   rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
188   return success();
189 }
190 
191 namespace {
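/// Converts `shape.const_size` to a `constant` of `index` type, e.g.
/// `%s = shape.const_size 42` becomes `%s = constant 42 : index`.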
192 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
193 public:
194   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
195 
196   LogicalResult
197   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
198                   ConversionPatternRewriter &rewriter) const override;
199 };
200 } // namespace
201 
202 LogicalResult ConstSizeOpConversion::matchAndRewrite(
203     ConstSizeOp op, ArrayRef<Value> operands,
204     ConversionPatternRewriter &rewriter) const {
205   rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
206   return success();
207 }
208 
209 namespace {
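/// Converts `shape.get_extent` to `extract_element` on the extent tensor, or
/// directly to `dim` when the shape stems from a `shape.shape_of` on a tensor
/// argument. Sketch of the fast path (names illustrative):
///
///   %shape = shape.shape_of %tensor : tensor<?x?xf32> -> tensor<?xindex>
///   %extent = shape.get_extent %shape, %dim
///
/// becomes
///
///   %extent = dim %tensor, %dim : tensor<?x?xf32>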
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;
212 
213   LogicalResult
214   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
215                   ConversionPatternRewriter &rewriter) const override;
216 };
217 } // namespace
218 
219 LogicalResult GetExtentOpConverter::matchAndRewrite(
220     GetExtentOp op, ArrayRef<Value> operands,
221     ConversionPatternRewriter &rewriter) const {
222   GetExtentOp::Adaptor transformed(operands);
223 
224   // For now, only error-free types are supported by this lowering.
225   if (op.getType().isa<SizeType>())
226     return failure();
227 
  // Derive the shape extent directly from the shape's origin if possible.
  // This avoids materializing the shape in memory.
230   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
231     if (shapeOfOp.arg().getType().isa<ShapedType>()) {
232       rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
233                                          transformed.dim());
234       return success();
235     }
236   }
237 
238   rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
239                                                 transformed.shape(),
240                                                 ValueRange{transformed.dim()});
241   return success();
242 }
243 
244 namespace {
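/// Converts `shape.rank` on an extent tensor to the tensor's number of
/// elements, i.e. `%rank = shape.rank %shape` becomes
/// `%rank = dim %shape, %c0 : tensor<?xindex>`.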
245 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
246 public:
247   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
248 
249   LogicalResult
250   matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
251                   ConversionPatternRewriter &rewriter) const override;
252 };
253 } // namespace
254 
255 LogicalResult
256 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
257                                  ConversionPatternRewriter &rewriter) const {
258   // For now, this lowering supports only error-free types.
259   if (op.getType().isa<SizeType>())
260     return failure();
261 
262   shape::RankOp::Adaptor transformed(operands);
263   rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
264   return success();
265 }
266 
267 namespace {
268 /// Converts `shape.reduce` to `scf.for`.
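/// The body of the `shape.reduce` op is cloned into the loop, with the
/// induction variable, the current extent, and the iteration arguments mapped
/// onto its block arguments. A sketch (assembly is illustrative):
///
///   %num = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
///     ^bb0(%i : index, %extent : index, %acc : index):
///       %next = shape.mul %acc, %extent
///       shape.yield %next : index
///   }
///
/// becomes roughly, once all patterns have been applied:
///
///   %rank = dim %shape, %c0 : tensor<?xindex>
///   %num = scf.for %i = %c0 to %rank step %c1
///       iter_args(%acc = %c1) -> (index) {
///     %extent = extract_element %shape[%i] : tensor<?xindex>
///     %next = muli %acc, %extent : index
///     scf.yield %next : index
///   }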
269 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
270 public:
271   using OpConversionPattern::OpConversionPattern;
272 
273   LogicalResult
274   matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
275                   ConversionPatternRewriter &rewriter) const final;
276 };
277 } // namespace
278 
279 LogicalResult
280 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
281                                    ConversionPatternRewriter &rewriter) const {
282   // For now, this lowering is only defined on `tensor<?xindex>` operands.
283   if (op.shape().getType().isa<ShapeType>())
284     return failure();
285 
286   auto loc = op.getLoc();
287   shape::ReduceOp::Adaptor transformed(operands);
288 
289   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
290   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
291   Type indexTy = rewriter.getIndexType();
292   Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
293 
294   auto loop = rewriter.create<scf::ForOp>(
295       loc, zero, rank, one, op.initVals(),
296       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
297         Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
298 
299         SmallVector<Value, 2> mappedValues{iv, extent};
300         mappedValues.append(args.begin(), args.end());
301 
302         BlockAndValueMapping mapping;
303         Block *reduceBody = op.getBody();
304         mapping.map(reduceBody->getArguments(), mappedValues);
305         for (auto &nested : reduceBody->without_terminator())
306           b.clone(nested, mapping);
307 
308         SmallVector<Value, 2> mappedResults;
309         for (auto result : reduceBody->getTerminator()->getOperands())
310           mappedResults.push_back(mapping.lookup(result));
311         b.create<scf::YieldOp>(loc, mappedResults);
312       });
313 
314   rewriter.replaceOp(op, loop.getResults());
315   return success();
316 }
317 
318 namespace {
319 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the ranks and, if they are equal, checks every extent pairwise
/// for equality.
322 ///
323 /// Example:
324 ///
325 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
326 ///
327 /// becomes
328 ///
329 /// %c0 = constant 0 : index
330 /// %0 = dim %arg0, %c0 : tensor<?xindex>
331 /// %1 = dim %arg1, %c0 : tensor<?xindex>
332 /// %2 = cmpi "eq", %0, %1 : index
333 /// %result = scf.if %2 -> (i1) {
334 ///   %c1 = constant 1 : index
335 ///   %true = constant true
336 ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
337 ///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
338 ///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
339 ///     %7 = cmpi "eq", %5, %6 : index
340 ///     %8 = and %arg3, %7 : i1
341 ///     scf.yield %8 : i1
342 ///   }
343 ///   scf.yield %4 : i1
344 /// } else {
345 ///   %false = constant false
346 ///   scf.yield %false : i1
347 /// }
348 ///
349 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
350   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
351 
352   LogicalResult
353   matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
354                   ConversionPatternRewriter &rewriter) const override;
355 };
356 } // namespace
357 
358 LogicalResult
359 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
360                                     ConversionPatternRewriter &rewriter) const {
361   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
362   // on shapes.
363   if (op.lhs().getType().isa<ShapeType>() ||
364       op.rhs().getType().isa<ShapeType>()) {
365     return failure();
366   }
367 
368   ShapeEqOp::Adaptor transformed(operands);
369   auto loc = op.getLoc();
370   Type indexTy = rewriter.getIndexType();
371   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
372   Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
373   Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
374   Value eqRank =
375       rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
376   Type i1Ty = rewriter.getI1Type();
377   rewriter.replaceOpWithNewOp<IfOp>(
378       op, i1Ty, eqRank,
379       [&](OpBuilder &b, Location loc) {
380         Value one = b.create<ConstantIndexOp>(loc, 1);
381         Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
382         auto loop = b.create<scf::ForOp>(
383             loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
385               Value conj = args[0];
386               Value lhsExtent =
387                   b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
388               Value rhsExtent =
389                   b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
390               Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
391                                                 lhsExtent, rhsExtent);
392               Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
393               b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
394             });
395         b.create<scf::YieldOp>(loc, loop.getResults());
396       },
397       [&](OpBuilder &b, Location loc) {
398         Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
399         b.create<scf::YieldOp>(loc, result);
400       });
401   return success();
402 }
403 
404 namespace {
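/// Converts `shape.shape_of` to `tensor_from_elements` for ranked tensor
/// arguments, using `dim` for dynamic extents, and to
/// `dynamic_tensor_from_elements` for unranked arguments. Sketch of the
/// unranked case (exact assembly is illustrative):
///
///   %rank = rank %arg : tensor<*xf32>
///   %shape = dynamic_tensor_from_elements %rank {
///   ^bb0(%i : index):
///     %extent = dim %arg, %i : tensor<*xf32>
///     yield %extent : index
///   } : tensor<?xindex>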
405 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
406 public:
407   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
408 
409   LogicalResult
410   matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
411                   ConversionPatternRewriter &rewriter) const override;
412 };
413 } // namespace
414 
415 LogicalResult ShapeOfOpConversion::matchAndRewrite(
416     ShapeOfOp op, ArrayRef<Value> operands,
417     ConversionPatternRewriter &rewriter) const {
418 
419   // For now, only error-free types are supported by this lowering.
420   if (op.getType().isa<ShapeType>())
421     return failure();
422 
423   // For ranked tensor arguments, lower to `tensor_from_elements`.
424   auto loc = op.getLoc();
425   ShapeOfOp::Adaptor transformed(operands);
426   Value tensor = transformed.arg();
427   Type tensorTy = tensor.getType();
428   if (tensorTy.isa<RankedTensorType>()) {
429 
430     // Build values for individual extents.
431     SmallVector<Value, 8> extentValues;
432     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
433     int64_t rank = rankedTensorTy.getRank();
434     for (int64_t i = 0; i < rank; i++) {
435       if (rankedTensorTy.isDynamicDim(i)) {
436         Value extent = rewriter.create<DimOp>(loc, tensor, i);
437         extentValues.push_back(extent);
438       } else {
439         Value extent =
440             rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
441         extentValues.push_back(extent);
442       }
443     }
444 
445     // Materialize extent tensor.
446     Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
447         loc, rewriter.getIndexType(), extentValues);
448     rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
449                                               op.getType());
450     return success();
451   }
452 
453   // Lower to `dynamic_tensor_from_elements` otherwise.
454   auto *ctx = rewriter.getContext();
455   Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
456   rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
457       op, getExtentTensorType(ctx), ValueRange{rank},
458       [&](OpBuilder &b, Location loc, ValueRange args) {
459         Value dim = args.front();
460         Value extent = b.create<DimOp>(loc, tensor, dim);
461         b.create<mlir::YieldOp>(loc, extent);
462       });
463 
464   return success();
465 }
466 
467 namespace {
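/// Converts `shape.to_extent_tensor` to a `tensor_cast`. Operands that are
/// already ranked extent tensors only need their type adjusted, e.g.
/// `tensor_cast %input : tensor<?xindex> to tensor<3xindex>`.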
468 class ToExtentTensorOpConversion
469     : public OpConversionPattern<ToExtentTensorOp> {
470 public:
471   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
472 
473   LogicalResult
474   matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
475                   ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOp::Adaptor adaptor(operands);
477 
478     if (!adaptor.input().getType().isa<RankedTensorType>())
479       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
480 
481     rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
482                                               op.getType());
483     return success();
484   }
485 };
486 } // namespace
487 
488 namespace {
489 /// Conversion pass.
490 class ConvertShapeToStandardPass
491     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
492 
493   void runOnOperation() override;
494 };
495 } // namespace
496 
497 void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
499   MLIRContext &ctx = getContext();
500   ConversionTarget target(ctx);
501   target.addLegalDialect<StandardOpsDialect, SCFDialect>();
502   target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
503 
  // Set up conversion patterns.
505   OwningRewritePatternList patterns;
506   populateShapeToStandardConversionPatterns(patterns, &ctx);
507 
508   // Apply conversion.
509   auto module = getOperation();
510   if (failed(applyPartialConversion(module, target, patterns)))
511     signalPassFailure();
512 }
513 
514 void mlir::populateShapeToStandardConversionPatterns(
515     OwningRewritePatternList &patterns, MLIRContext *ctx) {
516   // clang-format off
517   patterns.insert<
518       AnyOpConversion,
519       BinaryOpConversion<AddOp, AddIOp>,
520       BinaryOpConversion<MulOp, MulIOp>,
521       BroadcastOpConverter,
522       ConstShapeOpConverter,
523       ConstSizeOpConversion,
524       GetExtentOpConverter,
525       RankOpConverter,
526       ReduceOpConverter,
527       ShapeEqOpConverter,
528       ShapeOfOpConversion,
529       ToExtentTensorOpConversion>(ctx);
530   // clang-format on
531 }
532 
533 std::unique_ptr<OperationPass<ModuleOp>>
534 mlir::createConvertShapeToStandardPass() {
535   return std::make_unique<ConvertShapeToStandardPass>();
536 }
537