1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Operation.h" 13 #include "llvm/ADT/ScopeExit.h" 14 #include "llvm/ADT/SmallPtrSet.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // TransformState 20 //===----------------------------------------------------------------------===// 21 22 constexpr const Value transform::TransformState::kTopLevelValue; 23 24 transform::TransformState::TransformState(Region ®ion, Operation *root) 25 : topLevel(root) { 26 auto result = mappings.try_emplace(®ion); 27 assert(result.second && "the region scope is already present"); 28 (void)result; 29 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 30 regionStack.push_back(®ion); 31 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 32 } 33 34 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 35 36 ArrayRef<Operation *> 37 transform::TransformState::getPayloadOps(Value value) const { 38 const TransformOpMapping &operationMapping = getMapping(value).direct; 39 auto iter = operationMapping.find(value); 40 assert(iter != operationMapping.end() && "unknown handle"); 41 return iter->getSecond(); 42 } 43 44 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const { 45 for (const Mappings &mapping : llvm::make_second_range(mappings)) { 46 if (Value handle = mapping.reverse.lookup(op)) 47 return handle; 48 } 49 return Value(); 50 } 51 52 LogicalResult transform::TransformState::tryEmplaceReverseMapping( 53 Mappings &map, Operation *operation, Value handle) { 54 auto insertionResult = map.reverse.insert({operation, handle}); 55 if (!insertionResult.second) { 56 InFlightDiagnostic diag = operation->emitError() 57 << "operation tracked by two handles"; 58 diag.attachNote(handle.getLoc()) << "handle"; 59 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 60 return diag; 61 } 62 return success(); 63 } 64 65 LogicalResult 66 transform::TransformState::setPayloadOps(Value value, 67 ArrayRef<Operation *> targets) { 68 assert(value != kTopLevelValue && 69 "attempting to reset the transformation root"); 70 71 if (value.use_empty()) 72 return success(); 73 74 // Setting new payload for the value without cleaning it first is a misuse of 75 // the API, assert here. 76 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 77 Mappings &mappings = getMapping(value); 78 bool inserted = 79 mappings.direct.insert({value, std::move(storedTargets)}).second; 80 assert(inserted && "value is already associated with another list"); 81 (void)inserted; 82 83 // Having multiple handles to the same operation is an error in the transform 84 // expressed using the dialect and may be constructed by valid API calls from 85 // valid IR. Emit an error here. 86 for (Operation *op : targets) { 87 if (failed(tryEmplaceReverseMapping(mappings, op, value))) 88 return failure(); 89 } 90 91 return success(); 92 } 93 94 void transform::TransformState::removePayloadOps(Value value) { 95 Mappings &mappings = getMapping(value); 96 for (Operation *op : mappings.direct[value]) 97 mappings.reverse.erase(op); 98 mappings.direct.erase(value); 99 } 100 101 LogicalResult transform::TransformState::updatePayloadOps( 102 Value value, function_ref<Operation *(Operation *)> callback) { 103 Mappings &mappings = getMapping(value); 104 auto it = mappings.direct.find(value); 105 assert(it != mappings.direct.end() && "unknown handle"); 106 SmallVector<Operation *> &association = it->getSecond(); 107 SmallVector<Operation *> updated; 108 updated.reserve(association.size()); 109 110 for (Operation *op : association) { 111 mappings.reverse.erase(op); 112 if (Operation *updatedOp = callback(op)) { 113 updated.push_back(updatedOp); 114 if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value))) 115 return failure(); 116 } 117 } 118 119 std::swap(association, updated); 120 return success(); 121 } 122 123 LogicalResult 124 transform::TransformState::applyTransform(TransformOpInterface transform) { 125 transform::TransformResults results(transform->getNumResults()); 126 if (failed(transform.apply(results, *this))) 127 return failure(); 128 129 // Remove the mapping for the operand if it is consumed by the operation. This 130 // allows us to catch use-after-free with assertions later on. 131 auto memEffectInterface = 132 cast<MemoryEffectOpInterface>(transform.getOperation()); 133 SmallVector<MemoryEffects::EffectInstance, 2> effects; 134 for (Value target : transform->getOperands()) { 135 effects.clear(); 136 memEffectInterface.getEffectsOnValue(target, effects); 137 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 138 return isa<transform::TransformMappingResource>( 139 effect.getResource()) && 140 isa<MemoryEffects::Free>(effect.getEffect()); 141 })) { 142 removePayloadOps(target); 143 } 144 } 145 146 for (auto &en : llvm::enumerate(transform->getResults())) { 147 assert(en.value().getDefiningOp() == transform.getOperation() && 148 "payload IR association for a value other than the result of the " 149 "current transform op"); 150 if (failed(setPayloadOps(en.value(), results.get(en.index())))) 151 return failure(); 152 } 153 154 return success(); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // TransformState::Extension 159 //===----------------------------------------------------------------------===// 160 161 transform::TransformState::Extension::~Extension() = default; 162 163 LogicalResult 164 transform::TransformState::Extension::replacePayloadOp(Operation *op, 165 Operation *replacement) { 166 return state.updatePayloadOps(state.getHandleForPayloadOp(op), 167 [&](Operation *current) { 168 return current == op ? replacement : current; 169 }); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // TransformResults 174 //===----------------------------------------------------------------------===// 175 176 transform::TransformResults::TransformResults(unsigned numSegments) { 177 segments.resize(numSegments, 178 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 179 } 180 181 void transform::TransformResults::set(OpResult value, 182 ArrayRef<Operation *> ops) { 183 unsigned position = value.getResultNumber(); 184 assert(position < segments.size() && 185 "setting results for a non-existent handle"); 186 assert(segments[position].data() == nullptr && "results already set"); 187 unsigned start = operations.size(); 188 llvm::append_range(operations, ops); 189 segments[position] = makeArrayRef(operations).drop_front(start); 190 } 191 192 ArrayRef<Operation *> 193 transform::TransformResults::get(unsigned resultNumber) const { 194 assert(resultNumber < segments.size() && 195 "querying results for a non-existent handle"); 196 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 197 return segments[resultNumber]; 198 } 199 200 //===----------------------------------------------------------------------===// 201 // Utilities for PossibleTopLevelTransformOpTrait. 202 //===----------------------------------------------------------------------===// 203 204 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 205 TransformState &state, Operation *op) { 206 SmallVector<Operation *> targets; 207 if (op->getNumOperands() != 0) 208 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 209 else 210 targets.push_back(state.getTopLevel()); 211 212 return state.mapBlockArguments(op->getRegion(0).front().getArgument(0), 213 targets); 214 } 215 216 LogicalResult 217 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 218 // Attaching this trait without the interface is a misuse of the API, but it 219 // cannot be caught via a static_assert because interface registration is 220 // dynamic. 221 assert(isa<TransformOpInterface>(op) && 222 "should implement TransformOpInterface to have " 223 "PossibleTopLevelTransformOpTrait"); 224 225 if (op->getNumRegions() != 1) 226 return op->emitOpError() << "expects one region"; 227 228 Region *bodyRegion = &op->getRegion(0); 229 if (!llvm::hasNItems(*bodyRegion, 1)) 230 return op->emitOpError() << "expects a single-block region"; 231 232 Block *body = &bodyRegion->front(); 233 if (body->getNumArguments() != 1 || 234 !body->getArgumentTypes()[0].isa<pdl::OperationType>()) { 235 return op->emitOpError() 236 << "expects the entry block to have one argument of type " 237 << pdl::OperationType::get(op->getContext()); 238 } 239 240 if (auto *parent = 241 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 242 if (op->getNumOperands() == 0) { 243 InFlightDiagnostic diag = 244 op->emitOpError() 245 << "expects the root operation to be provided for a nested op"; 246 diag.attachNote(parent->getLoc()) 247 << "nested in another possible top-level op"; 248 return diag; 249 } 250 } 251 252 return success(); 253 } 254 255 //===----------------------------------------------------------------------===// 256 // Generated interface implementation. 257 //===----------------------------------------------------------------------===// 258 259 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 260