1 //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to rewrite sequential chains of 10 // `spirv::CompositeInsert` operations into `spirv::CompositeConstruct` 11 // operations. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/Passes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinOps.h" 20 21 using namespace mlir; 22 23 namespace { 24 25 /// Replaces sequential chains of `spirv::CompositeInsertOp` operation into 26 /// `spirv::CompositeConstructOp` operation if possible. 27 class RewriteInsertsPass 28 : public SPIRVRewriteInsertsPassBase<RewriteInsertsPass> { 29 public: 30 void runOnOperation() override; 31 32 private: 33 /// Collects a sequential insertion chain by the given 34 /// `spirv::CompositeInsertOp` operation, if the given operation is the last 35 /// in the chain. 36 LogicalResult 37 collectInsertionChain(spirv::CompositeInsertOp op, 38 SmallVectorImpl<spirv::CompositeInsertOp> &insertions); 39 }; 40 41 } // namespace 42 43 void RewriteInsertsPass::runOnOperation() { 44 SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList; 45 getOperation().walk([this, &workList](spirv::CompositeInsertOp op) { 46 SmallVector<spirv::CompositeInsertOp, 4> insertions; 47 if (succeeded(collectInsertionChain(op, insertions))) 48 workList.push_back(insertions); 49 }); 50 51 for (const auto &insertions : workList) { 52 auto lastCompositeInsertOp = insertions.back(); 53 auto compositeType = lastCompositeInsertOp.getType(); 54 auto location = lastCompositeInsertOp.getLoc(); 55 56 SmallVector<Value, 4> operands; 57 // Collect inserted objects. 58 for (auto insertionOp : insertions) 59 operands.push_back(insertionOp.object()); 60 61 OpBuilder builder(lastCompositeInsertOp); 62 auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>( 63 location, compositeType, operands); 64 65 lastCompositeInsertOp.replaceAllUsesWith( 66 compositeConstructOp->getResult(0)); 67 68 // Erase ops. 69 for (auto insertOp : llvm::reverse(insertions)) { 70 auto *op = insertOp.getOperation(); 71 if (op->use_empty()) 72 insertOp.erase(); 73 } 74 } 75 } 76 77 LogicalResult RewriteInsertsPass::collectInsertionChain( 78 spirv::CompositeInsertOp op, 79 SmallVectorImpl<spirv::CompositeInsertOp> &insertions) { 80 auto indicesArrayAttr = op.indices().cast<ArrayAttr>(); 81 // TODO: handle nested composite object. 82 if (indicesArrayAttr.size() == 1) { 83 auto numElements = 84 op.composite().getType().cast<spirv::CompositeType>().getNumElements(); 85 86 auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt(); 87 // Need a last index to collect a sequential chain. 88 if (index + 1 != numElements) 89 return failure(); 90 91 insertions.resize(numElements); 92 while (true) { 93 insertions[index] = op; 94 95 if (index == 0) 96 return success(); 97 98 op = op.composite().getDefiningOp<spirv::CompositeInsertOp>(); 99 if (!op) 100 return failure(); 101 102 --index; 103 indicesArrayAttr = op.indices().cast<ArrayAttr>(); 104 if ((indicesArrayAttr.size() != 1) || 105 (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index)) 106 return failure(); 107 } 108 } 109 return failure(); 110 } 111 112 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>> 113 mlir::spirv::createRewriteInsertsPass() { 114 return std::make_unique<RewriteInsertsPass>(); 115 } 116