//===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/IR/Builders.h"
11 
12 #include "mlir/IR/OpImplementation.h"
13 
14 using namespace mlir;
15 
16 #define GET_OP_CLASSES
17 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
18 
19 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
20                                            transform::TransformState &state) {
21   SmallVector<Operation *> targets;
22   if (getRoot())
23     llvm::append_range(targets, state.getPayloadOps(getRoot()));
24   else
25     targets.push_back(state.getTopLevel());
26 
27   // Map the entry block argument to the list of operations.
28   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
29   if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets)))
30     return failure();
31 
32   // Apply the sequenced ops one by one.
33   for (Operation &transform : getBodyBlock()->without_terminator())
34     if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
35       return failure();
36 
37   // Forward the operation mapping for values yielded from the sequence to the
38   // values produced by the sequence op.
39   for (const auto &pair :
40        llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
41                  getOperation()->getOpResults())) {
42     Value terminatorOperand = std::get<0>(pair);
43     OpResult result = std::get<1>(pair);
44     results.set(result, state.getPayloadOps(terminatorOperand));
45   }
46 
47   return success();
48 }
49 
50 LogicalResult transform::SequenceOp::verify() {
51   if (getBodyBlock()->getNumArguments() != 1 ||
52       !getBodyBlock()->getArgumentTypes()[0].isa<pdl::OperationType>()) {
53     return emitOpError()
54            << "expected the entry block to have one argument of type "
55            << pdl::OperationType::get(getContext());
56   }
57 
58   if (auto parent = getOperation()->getParentOfType<transform::SequenceOp>()) {
59     if (!getRoot()) {
60       InFlightDiagnostic diag =
61           emitOpError()
62           << "expected the root operation to be provided for a nested sequence";
63       diag.attachNote(parent.getLoc()) << "nested in another sequence";
64       return diag;
65     }
66   }
67 
68   for (Operation &child : *getBodyBlock()) {
69     if (!isa<TransformOpInterface>(child) &&
70         &child != &getBodyBlock()->back()) {
71       InFlightDiagnostic diag =
72           emitOpError()
73           << "expected children ops to implement TransformOpInterface";
74       diag.attachNote(child.getLoc()) << "op without interface";
75       return diag;
76     }
77 
78     for (OpResult result : child.getResults()) {
79       if (llvm::hasNItemsOrLess(result.getUses(), 1))
80         continue;
81       InFlightDiagnostic diag = child.emitError()
82                                 << "result #" << result.getResultNumber()
83                                 << " has more than one use";
84       for (OpOperand &use : result.getUses()) {
85         diag.attachNote(use.getOwner()->getLoc())
86             << "used here as operand #" << use.getOperandNumber();
87       }
88       return diag;
89     }
90   }
91 
92   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
93       getOperation()->getResultTypes()) {
94     InFlightDiagnostic diag = emitOpError()
95                               << "expects the types of the terminator operands "
96                                  "to match the types of the result";
97     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
98     return diag;
99   }
100   return success();
101 }
102