1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Operation.h" 13 #include "llvm/ADT/ScopeExit.h" 14 #include "llvm/ADT/SmallPtrSet.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // TransformState 20 //===----------------------------------------------------------------------===// 21 22 constexpr const Value transform::TransformState::kTopLevelValue; 23 24 transform::TransformState::TransformState(Region ®ion, Operation *root) 25 : topLevel(root) { 26 auto result = mappings.try_emplace(®ion); 27 assert(result.second && "the region scope is already present"); 28 (void)result; 29 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 30 regionStack.push_back(®ion); 31 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 32 } 33 34 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 35 36 ArrayRef<Operation *> 37 transform::TransformState::getPayloadOps(Value value) const { 38 const TransformOpMapping &operationMapping = getMapping(value).direct; 39 auto iter = operationMapping.find(value); 40 assert(iter != operationMapping.end() && "unknown handle"); 41 return iter->getSecond(); 42 } 43 44 LogicalResult 45 transform::TransformState::setPayloadOps(Value value, 46 ArrayRef<Operation *> targets) { 47 assert(value != kTopLevelValue && 48 "attempting to reset the transformation root"); 49 50 if (value.use_empty()) 51 return success(); 52 53 // Setting new payload for the value without cleaning it first is a misuse of 54 // the API, assert here. 55 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 56 Mappings &mappings = getMapping(value); 57 bool inserted = 58 mappings.direct.insert({value, std::move(storedTargets)}).second; 59 assert(inserted && "value is already associated with another list"); 60 (void)inserted; 61 62 // Having multiple handles to the same operation is an error in the transform 63 // expressed using the dialect and may be constructed by valid API calls from 64 // valid IR. Emit an error here. 65 for (Operation *op : targets) { 66 auto insertionResult = mappings.reverse.insert({op, value}); 67 if (!insertionResult.second) { 68 InFlightDiagnostic diag = op->emitError() 69 << "operation tracked by two handles"; 70 diag.attachNote(value.getLoc()) << "handle"; 71 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 72 return diag; 73 } 74 } 75 76 return success(); 77 } 78 79 void transform::TransformState::removePayloadOps(Value value) { 80 Mappings &mappings = getMapping(value); 81 for (Operation *op : mappings.direct[value]) 82 mappings.reverse.erase(op); 83 mappings.direct.erase(value); 84 } 85 86 void transform::TransformState::updatePayloadOps( 87 Value value, function_ref<Operation *(Operation *)> callback) { 88 auto it = getMapping(value).direct.find(value); 89 assert(it != getMapping(value).direct.end() && "unknown handle"); 90 SmallVector<Operation *> &association = it->getSecond(); 91 SmallVector<Operation *> updated; 92 updated.reserve(association.size()); 93 94 for (Operation *op : association) 95 if (Operation *updatedOp = callback(op)) 96 updated.push_back(updatedOp); 97 98 std::swap(association, updated); 99 } 100 101 LogicalResult 102 transform::TransformState::applyTransform(TransformOpInterface transform) { 103 transform::TransformResults results(transform->getNumResults()); 104 if (failed(transform.apply(results, *this))) 105 return failure(); 106 107 for (Value target : transform->getOperands()) 108 removePayloadOps(target); 109 110 for (auto &en : llvm::enumerate(transform->getResults())) { 111 assert(en.value().getDefiningOp() == transform.getOperation() && 112 "payload IR association for a value other than the result of the " 113 "current transform op"); 114 if (failed(setPayloadOps(en.value(), results.get(en.index())))) 115 return failure(); 116 } 117 118 return success(); 119 } 120 121 transform::TransformState::Extension::~Extension() = default; 122 123 //===----------------------------------------------------------------------===// 124 // TransformResults 125 //===----------------------------------------------------------------------===// 126 127 transform::TransformResults::TransformResults(unsigned numSegments) { 128 segments.resize(numSegments, 129 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 130 } 131 132 void transform::TransformResults::set(OpResult value, 133 ArrayRef<Operation *> ops) { 134 unsigned position = value.getResultNumber(); 135 assert(position < segments.size() && 136 "setting results for a non-existent handle"); 137 assert(segments[position].data() == nullptr && "results already set"); 138 unsigned start = operations.size(); 139 llvm::append_range(operations, ops); 140 segments[position] = makeArrayRef(operations).drop_front(start); 141 } 142 143 ArrayRef<Operation *> 144 transform::TransformResults::get(unsigned resultNumber) const { 145 assert(resultNumber < segments.size() && 146 "querying results for a non-existent handle"); 147 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 148 return segments[resultNumber]; 149 } 150 151 //===----------------------------------------------------------------------===// 152 // Utilities for PossibleTopLevelTransformOpTrait. 153 //===----------------------------------------------------------------------===// 154 155 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 156 TransformState &state, Operation *op) { 157 SmallVector<Operation *> targets; 158 if (op->getNumOperands() != 0) 159 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 160 else 161 targets.push_back(state.getTopLevel()); 162 163 return state.mapBlockArguments(op->getRegion(0).front().getArgument(0), 164 targets); 165 } 166 167 LogicalResult 168 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 169 // Attaching this trait without the interface is a misuse of the API, but it 170 // cannot be caught via a static_assert because interface registration is 171 // dynamic. 172 assert(isa<TransformOpInterface>(op) && 173 "should implement TransformOpInterface to have " 174 "PossibleTopLevelTransformOpTrait"); 175 176 if (op->getNumRegions() != 1) 177 return op->emitOpError() << "expects one region"; 178 179 Region *bodyRegion = &op->getRegion(0); 180 if (!llvm::hasNItems(*bodyRegion, 1)) 181 return op->emitOpError() << "expects a single-block region"; 182 183 Block *body = &bodyRegion->front(); 184 if (body->getNumArguments() != 1 || 185 !body->getArgumentTypes()[0].isa<pdl::OperationType>()) { 186 return op->emitOpError() 187 << "expects the entry block to have one argument of type " 188 << pdl::OperationType::get(op->getContext()); 189 } 190 191 if (auto *parent = 192 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 193 if (op->getNumOperands() == 0) { 194 InFlightDiagnostic diag = 195 op->emitOpError() 196 << "expects the root operation to be provided for a nested op"; 197 diag.attachNote(parent->getLoc()) 198 << "nested in another possible top-level op"; 199 return diag; 200 } 201 } 202 203 return success(); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // Generated interface implementation. 208 //===----------------------------------------------------------------------===// 209 210 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 211