1a2004c34SDenis Khalikov //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
2a2004c34SDenis Khalikov //
3a2004c34SDenis Khalikov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a2004c34SDenis Khalikov // See https://llvm.org/LICENSE.txt for license information.
5a2004c34SDenis Khalikov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a2004c34SDenis Khalikov //
7a2004c34SDenis Khalikov //===----------------------------------------------------------------------===//
8a2004c34SDenis Khalikov //
9a2004c34SDenis Khalikov // This file implements a pass to rewrite sequential chains of
10a2004c34SDenis Khalikov // `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
11a2004c34SDenis Khalikov // operations.
12a2004c34SDenis Khalikov //
13a2004c34SDenis Khalikov //===----------------------------------------------------------------------===//
14a2004c34SDenis Khalikov 
15a2004c34SDenis Khalikov #include "PassDetail.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1701178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
18a2004c34SDenis Khalikov #include "mlir/IR/Builders.h"
1965fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
20a2004c34SDenis Khalikov 
21a2004c34SDenis Khalikov using namespace mlir;
22a2004c34SDenis Khalikov 
23a2004c34SDenis Khalikov namespace {
24a2004c34SDenis Khalikov 
25a2004c34SDenis Khalikov /// Replaces sequential chains of `spirv::CompositeInsertOp` operation into
26a2004c34SDenis Khalikov /// `spirv::CompositeConstructOp` operation if possible.
27a2004c34SDenis Khalikov class RewriteInsertsPass
28a2004c34SDenis Khalikov     : public SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
29a2004c34SDenis Khalikov public:
30a2004c34SDenis Khalikov   void runOnOperation() override;
31a2004c34SDenis Khalikov 
32a2004c34SDenis Khalikov private:
33a2004c34SDenis Khalikov   /// Collects a sequential insertion chain by the given
34a2004c34SDenis Khalikov   /// `spirv::CompositeInsertOp` operation, if the given operation is the last
35a2004c34SDenis Khalikov   /// in the chain.
36a2004c34SDenis Khalikov   LogicalResult
37a2004c34SDenis Khalikov   collectInsertionChain(spirv::CompositeInsertOp op,
38a2004c34SDenis Khalikov                         SmallVectorImpl<spirv::CompositeInsertOp> &insertions);
39a2004c34SDenis Khalikov };
40a2004c34SDenis Khalikov 
41*be0a7e9fSMehdi Amini } // namespace
42a2004c34SDenis Khalikov 
runOnOperation()43a2004c34SDenis Khalikov void RewriteInsertsPass::runOnOperation() {
44a2004c34SDenis Khalikov   SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList;
45a2004c34SDenis Khalikov   getOperation().walk([this, &workList](spirv::CompositeInsertOp op) {
46a2004c34SDenis Khalikov     SmallVector<spirv::CompositeInsertOp, 4> insertions;
47a2004c34SDenis Khalikov     if (succeeded(collectInsertionChain(op, insertions)))
48a2004c34SDenis Khalikov       workList.push_back(insertions);
49a2004c34SDenis Khalikov   });
50a2004c34SDenis Khalikov 
51a2004c34SDenis Khalikov   for (const auto &insertions : workList) {
52a2004c34SDenis Khalikov     auto lastCompositeInsertOp = insertions.back();
53a2004c34SDenis Khalikov     auto compositeType = lastCompositeInsertOp.getType();
54a2004c34SDenis Khalikov     auto location = lastCompositeInsertOp.getLoc();
55a2004c34SDenis Khalikov 
56a2004c34SDenis Khalikov     SmallVector<Value, 4> operands;
57a2004c34SDenis Khalikov     // Collect inserted objects.
58a2004c34SDenis Khalikov     for (auto insertionOp : insertions)
59a2004c34SDenis Khalikov       operands.push_back(insertionOp.object());
60a2004c34SDenis Khalikov 
61a2004c34SDenis Khalikov     OpBuilder builder(lastCompositeInsertOp);
62a2004c34SDenis Khalikov     auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
63a2004c34SDenis Khalikov         location, compositeType, operands);
64a2004c34SDenis Khalikov 
65a2004c34SDenis Khalikov     lastCompositeInsertOp.replaceAllUsesWith(
66c4a04059SChristian Sigg         compositeConstructOp->getResult(0));
67a2004c34SDenis Khalikov 
68a2004c34SDenis Khalikov     // Erase ops.
69a2004c34SDenis Khalikov     for (auto insertOp : llvm::reverse(insertions)) {
70a2004c34SDenis Khalikov       auto *op = insertOp.getOperation();
71a2004c34SDenis Khalikov       if (op->use_empty())
72a2004c34SDenis Khalikov         insertOp.erase();
73a2004c34SDenis Khalikov     }
74a2004c34SDenis Khalikov   }
75a2004c34SDenis Khalikov }
76a2004c34SDenis Khalikov 
collectInsertionChain(spirv::CompositeInsertOp op,SmallVectorImpl<spirv::CompositeInsertOp> & insertions)77a2004c34SDenis Khalikov LogicalResult RewriteInsertsPass::collectInsertionChain(
78a2004c34SDenis Khalikov     spirv::CompositeInsertOp op,
79a2004c34SDenis Khalikov     SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
80a2004c34SDenis Khalikov   auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
81a2004c34SDenis Khalikov   // TODO: handle nested composite object.
82a2004c34SDenis Khalikov   if (indicesArrayAttr.size() == 1) {
83a2004c34SDenis Khalikov     auto numElements =
84a2004c34SDenis Khalikov         op.composite().getType().cast<spirv::CompositeType>().getNumElements();
85a2004c34SDenis Khalikov 
86a2004c34SDenis Khalikov     auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
87a2004c34SDenis Khalikov     // Need a last index to collect a sequential chain.
88a2004c34SDenis Khalikov     if (index + 1 != numElements)
89a2004c34SDenis Khalikov       return failure();
90a2004c34SDenis Khalikov 
91a2004c34SDenis Khalikov     insertions.resize(numElements);
92a2004c34SDenis Khalikov     while (true) {
93a2004c34SDenis Khalikov       insertions[index] = op;
94a2004c34SDenis Khalikov 
95a2004c34SDenis Khalikov       if (index == 0)
96a2004c34SDenis Khalikov         return success();
97a2004c34SDenis Khalikov 
98a2004c34SDenis Khalikov       op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
99a2004c34SDenis Khalikov       if (!op)
100a2004c34SDenis Khalikov         return failure();
101a2004c34SDenis Khalikov 
102a2004c34SDenis Khalikov       --index;
103a2004c34SDenis Khalikov       indicesArrayAttr = op.indices().cast<ArrayAttr>();
104a2004c34SDenis Khalikov       if ((indicesArrayAttr.size() != 1) ||
105a2004c34SDenis Khalikov           (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
106a2004c34SDenis Khalikov         return failure();
107a2004c34SDenis Khalikov     }
108a2004c34SDenis Khalikov   }
109a2004c34SDenis Khalikov   return failure();
110a2004c34SDenis Khalikov }
111a2004c34SDenis Khalikov 
112a2004c34SDenis Khalikov std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
createRewriteInsertsPass()113a2004c34SDenis Khalikov mlir::spirv::createRewriteInsertsPass() {
114a2004c34SDenis Khalikov   return std::make_unique<RewriteInsertsPass>();
115a2004c34SDenis Khalikov }
116