1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/IR/Builders.h" 11 12 #include "mlir/IR/OpImplementation.h" 13 14 using namespace mlir; 15 16 #define GET_OP_CLASSES 17 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 18 19 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, 20 transform::TransformState &state) { 21 SmallVector<Operation *> targets; 22 if (getRoot()) 23 llvm::append_range(targets, state.getPayloadOps(getRoot())); 24 else 25 targets.push_back(state.getTopLevel()); 26 27 // Map the entry block argument to the list of operations. 28 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 29 if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets))) 30 return failure(); 31 32 // Apply the sequenced ops one by one. 33 for (Operation &transform : getBodyBlock()->without_terminator()) 34 if (failed(state.applyTransform(cast<TransformOpInterface>(transform)))) 35 return failure(); 36 37 // Forward the operation mapping for values yielded from the sequence to the 38 // values produced by the sequence op. 39 for (const auto &pair : 40 llvm::zip(getBodyBlock()->getTerminator()->getOperands(), 41 getOperation()->getOpResults())) { 42 Value terminatorOperand = std::get<0>(pair); 43 OpResult result = std::get<1>(pair); 44 results.set(result, state.getPayloadOps(terminatorOperand)); 45 } 46 47 return success(); 48 } 49 50 LogicalResult transform::SequenceOp::verify() { 51 if (getBodyBlock()->getNumArguments() != 1 || 52 !getBodyBlock()->getArgumentTypes()[0].isa<pdl::OperationType>()) { 53 return emitOpError() 54 << "expected the entry block to have one argument of type " 55 << pdl::OperationType::get(getContext()); 56 } 57 58 if (auto parent = getOperation()->getParentOfType<transform::SequenceOp>()) { 59 if (!getRoot()) { 60 InFlightDiagnostic diag = 61 emitOpError() 62 << "expected the root operation to be provided for a nested sequence"; 63 diag.attachNote(parent.getLoc()) << "nested in another sequence"; 64 return diag; 65 } 66 } 67 68 for (Operation &child : *getBodyBlock()) { 69 if (!isa<TransformOpInterface>(child) && 70 &child != &getBodyBlock()->back()) { 71 InFlightDiagnostic diag = 72 emitOpError() 73 << "expected children ops to implement TransformOpInterface"; 74 diag.attachNote(child.getLoc()) << "op without interface"; 75 return diag; 76 } 77 78 for (OpResult result : child.getResults()) { 79 if (llvm::hasNItemsOrLess(result.getUses(), 1)) 80 continue; 81 InFlightDiagnostic diag = child.emitError() 82 << "result #" << result.getResultNumber() 83 << " has more than one use"; 84 for (OpOperand &use : result.getUses()) { 85 diag.attachNote(use.getOwner()->getLoc()) 86 << "used here as operand #" << use.getOperandNumber(); 87 } 88 return diag; 89 } 90 } 91 92 if (getBodyBlock()->getTerminator()->getOperandTypes() != 93 getOperation()->getResultTypes()) { 94 InFlightDiagnostic diag = emitOpError() 95 << "expects the types of the terminator operands " 96 "to match the types of the result"; 97 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 98 return diag; 99 } 100 return success(); 101 } 102